diff --git a/bionemo-recipes/models/llama3/README.md b/bionemo-recipes/models/llama3/README.md
index ca0417c977..ef62f13b7e 100644
--- a/bionemo-recipes/models/llama3/README.md
+++ b/bionemo-recipes/models/llama3/README.md
@@ -19,7 +19,7 @@ The Llama-3 implementation natively supports the following TransformerEngine-pro
| **Export to HuggingFace checkpoints** | ✅ Supported |
| **KV-cache inference** | ✅ Supported (including beam search) |
| **Context Parallelism** | ✅ Supported |
-| **Tensor Parallelism** | 🚧 Under development |
+| **Tensor Parallelism** | ✅ Supported |
Refer to [BioNeMo Recipes](../../recipes/llama3_native_te/README.md) for more details on how to use these features to accelerate model
training and inference with native PyTorch training loops.
diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py
index c88a8f9154..9bbedf32aa 100644
--- a/bionemo-recipes/models/llama3/modeling_llama_te.py
+++ b/bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -25,6 +25,8 @@
import transformer_engine.common.recipe
import transformer_engine.pytorch
import transformers
+from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+from torch.distributed.tensor.placement_types import Replicate
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -52,6 +54,9 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ tensor_parallel: bool = False
+ sequence_parallel: bool = False
+ tp_size: int = 1
def __init__(
self,
@@ -142,6 +147,8 @@ def __init__(
config: LlamaConfig,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
+ nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlama model.
@@ -149,6 +156,8 @@ def __init__(
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
+ nvte_tp_mesh: TP DeviceMesh for the model.
+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
self.config = config
@@ -156,6 +165,8 @@ def __init__(
self.vocab_size = config.vocab_size
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
+ self.tp_mesh = nvte_tp_mesh
+ self.weight_mesh = nvte_weight_mesh
if self.config.layer_precision is None:
if fp8_recipe is not None and fp4_recipe is not None:
@@ -174,6 +185,27 @@ def __init__(
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
+ # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
+ if config.tensor_parallel:
+ assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh."
+ assert self.tp_mesh.size() == config.tp_size, (
+ f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
+ f"does not match configured TP size ({config.tp_size}).",
+ )
+ # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
+ # during HuggingFace post-init, this will automatically convert the TELinear head
+ # weight into a DTensor with the correct sharding placements prior to FSDP2
+ # fully_shard(), and no need to call TELinear.set_device_mesh().
+ parallelize_module(
+ self.embed_tokens,
+ self.tp_mesh,
+ # Un-sharded output activations for compatible input to TETransformer.
+ # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
+ # RowwiseParallel doesn't support output_layouts=Replicate() with
+ # torch.compile: https://github.com/pytorch/torchtitan/issues/534
+ ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
+ )
+
def _init_method(x):
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
@@ -201,6 +233,11 @@ def _init_method(x):
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=_init_method,
output_layer_init_method=_init_method,
+ set_parallel_mode=config.tensor_parallel,
+ sequence_parallel=config.sequence_parallel,
+ tp_size=config.tp_size,
+ tp_mesh=self.tp_mesh,
+ weight_mesh=self.weight_mesh,
)
]
@@ -211,6 +248,8 @@ def _init_method(x):
dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
)
+        # Norm modules do not subclass the TE base module, so TP needs this manual set_device_mesh() call.
+ self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh)
# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
# LlamaRotaryEmbedding.
@@ -387,6 +426,8 @@ def __init__(
config,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
+ nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlamaForCausalLM model.
@@ -394,10 +435,21 @@ def __init__(
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
+ nvte_tp_mesh: TP DeviceMesh for the model.
+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
- self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ self.model = NVLlamaModel(
+ config,
+ fp8_recipe=fp8_recipe,
+ fp4_recipe=fp4_recipe,
+ nvte_tp_mesh=nvte_tp_mesh,
+ nvte_weight_mesh=nvte_weight_mesh,
+ )
+ self.config = config
self.vocab_size = config.vocab_size
+ self.tp_mesh = nvte_tp_mesh
+ self.weight_mesh = nvte_weight_mesh
with transformer_engine.pytorch.quantized_model_init(enabled=False):
self.lm_head = transformer_engine.pytorch.Linear(
config.hidden_size,
@@ -406,9 +458,25 @@ def __init__(
params_dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
+ parallel_mode="row" if config.tensor_parallel else None,
+            # sequence_parallel would scatter the output, which is never
+            # needed for the final layer; the output is all-reduced instead.
+ sequence_parallel=False,
+ tp_size=config.tp_size,
)
-
- # Initialize weights and apply final processing
+ if config.tensor_parallel:
+ if config.tie_word_embeddings:
+ # Head weights have already been tied to the embedding weights.
+ # Just set the tensor parallel group for TE.
+ # No parameter quantization either, so no need for weight_mesh.
+ self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
+ else:
+ # Head weights are not tied to the embedding weights. Need to
+ # wrap the LM head weight as a DTensor with TE.
+ # No parameter quantization either, so no need for weight_mesh.
+ self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh)
+
+ # Initialize weights and apply final processing. Ties weights.
self.post_init()
def forward(
@@ -461,6 +529,14 @@ def forward(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ if self.config.tensor_parallel:
+            # Under TP, slice the hidden-state activation along the hidden
+            # dimension to feed the row-wise tensor-parallel LM head.
+ # Use ... to support both BSHD (3D) and THD (2D) hidden states.
+ tp_rank = self.tp_mesh.get_local_rank()
+ tp_stride = hidden_states.shape[-1] // self.config.tp_size
+ hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
+
with transformer_engine.pytorch.autocast(enabled=False):
if hidden_states.ndim == 3:
logits = self.lm_head(hidden_states[:, slice_indices, :])
diff --git a/bionemo-recipes/models/llama3/requirements.txt b/bionemo-recipes/models/llama3/requirements.txt
index ec6a547cb8..a16bb00438 100644
--- a/bionemo-recipes/models/llama3/requirements.txt
+++ b/bionemo-recipes/models/llama3/requirements.txt
@@ -1,5 +1,5 @@
lm-eval # For testing
torch
torchao!=0.14.0
-transformer_engine[pytorch]
+transformer_engine[pytorch] @ git+https://github.com/cspades/TransformerEngine.git@7e0d3a9b9cca4243bcf9f233f79fcb0c09795e3f
transformers
diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py
index 4b98cd36a3..9085803916 100644
--- a/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py
+++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py
@@ -135,7 +135,7 @@ def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path):
[
"torchrun",
"--nproc_per_node=2",
- "train_fsdp2_cp.py",
+ "train_fsdp2_nd_parallel.py",
"--config-name",
"L0_sanity_cp",
"num_train_steps=4",
diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
index 12d4313941..0ce3a0260d 100644
--- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
+++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
@@ -87,7 +87,7 @@ def main(args: DictConfig) -> float | None:
)
if args.use_fp32_master_weights:
- raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_cp.py instead.")
+ raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_nd_parallel.py instead.")
# Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
# Note: token_dropout is set to False because it's not compatible with context parallelism.
diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md
index d99a3001c0..e3e6b55d09 100644
--- a/bionemo-recipes/recipes/llama3_native_te/README.md
+++ b/bionemo-recipes/recipes/llama3_native_te/README.md
@@ -18,7 +18,7 @@ bionemo-framework repository. You can download a zipped directory of this folder
| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism |
| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ |
-| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
+| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
✅: Supported
🚧: Under development
@@ -94,7 +94,7 @@ This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Conte
- [Distributed Data Parallel (DDP)](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), shown in `train_ddp.py`
- [Fully Sharded Data Parallel 2 (FSDP2)](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html), shown in `train_fsdp2.py`
-- FSDP2 with Context Parallelism, shown in `train_fsdp2_cp.py`
+- FSDP2 with Context Parallelism, shown in `train_fsdp2_nd_parallel.py`
## Commands to Launch Training
@@ -193,13 +193,30 @@ python train_fsdp2.py --config-name L0_sanity \
### Context Parallel Training
Context parallelism splits each sequence across multiple GPUs along the sequence dimension, enabling training with very
-long sequences. Use `train_fsdp2_cp.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context
+long sequences. Use `train_fsdp2_nd_parallel.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context
parallelism ranks. Works with both BSHD (no padding) and THD (padding) input formats. Only TE models are supported.
```bash
-torchrun --nproc_per_node=4 train_fsdp2_cp.py --config-name L0_sanity_cp cp_size=2
+torchrun --nproc_per_node=4 train_fsdp2_nd_parallel.py --config-name L0_sanity_cp cp_size=2
```
+### Tensor Parallel Training
+
+Tensor parallelism shards model activations and weights along the hidden dimension across multiple GPUs, and is compatible with
+context parallelism and FSDP2 in TransformerEngine. Use `train_fsdp2_nd_parallel.py` with the `L2_sanity_nd` config and optionally
+set `tp_size` and `cp_size` such that the product of these sizes is less than or equal to the number of GPUs. When the FSDP and
+TP sharding dimensions coincide, FSDP2 will automatically create `_StridedShard` placements to ensure that TP shards are the
+result of the all-gather during training. Only TE models are supported; refer to the `set_device_mesh` API on TE modules for
+more information about how `DeviceMesh` and `DTensor` interact with TransformerEngine.
+
+```bash
+torchrun --nproc_per_node=4 train_fsdp2_nd_parallel.py --config-name L2_sanity_nd cp_size=2 tp_size=2
+```
+
+Note that TransformerEngine TP is compatible with DTensor-based TP. For instance, the `torch.nn.Embedding` is weight-tied to the
+Llama LM Head, but the embedding layer is tensor-parallelized with `parallelize_module` while the head is tensor-parallelized with
+TransformerEngine.
+
## Downloading Pre-Training Data For Offline Training
This recipe is configured to use genomic sequences. The default configuration uses a local test file
diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
index 2dc5d10dcf..b75fd00549 100644
--- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
+++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
@@ -233,9 +233,28 @@ class AppState(Stateful):
epoch: int = 0
def state_dict(self):
- """Get the state dict for the model, optimizer, scheduler, and step."""
+ """Get the state dict for the model, optimizer, scheduler, and step.
+
+ This factory both retrieves the model state dictionary when saving
+ checkpoints and initializes a destination for the state read from
+ DCP checkpoint files when loading checkpoints.
+ """
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
- model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")}
+ for fqn in list(model_state_dict.keys()):
+ # Get the model parameter.
+ model_param = model_state_dict[fqn]
+ if isinstance(model_param, DTensor):
+ model_param = model_param.to_local()
+ if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]:
+ # Empty model parameter. Clear the associated optimizer state
+ # when initializing the optimizer state upon DCP load, because
+ # empty optimizer state DTensors are not checkpointed with DCP,
+ # yet get_state_dict / _init_optim_state produce empty Tensors.
+ # TransformerEngine uses empty Tensors for dummy Parameters.
+ optimizer_state_dict["state"][fqn] = {}
+ if fqn.endswith("_extra_state"):
+ # Evict `_extra_state` quantization data from model checkpoint.
+ model_state_dict.pop(fqn)
return {
"model": model_state_dict,
"optim": optimizer_state_dict,
@@ -245,12 +264,18 @@ def state_dict(self):
}
def load_state_dict(self, state_dict: dict):
- """Load the state dict for the model, optimizer, scheduler, and step."""
+ """Load the state dict for the model, optimizer, scheduler, and step.
+
+ Given the checkpoint-loaded state_dict, set the state of the model,
+ optimizer, scheduler, step, and epoch to the values in state_dict.
+ """
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"],
+            # Non-strict checkpoint loading ignores empty optimizer states and
+            # skips loading non-FP8 checkpoint weights (e.g. _extra_state).
options=StateDictOptions(strict=False),
)
self.scheduler.load_state_dict(state_dict["scheduler"])
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_tp.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_tp.yaml
new file mode 100644
index 0000000000..601fdf097c
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_tp.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - L0_sanity
+ - _self_
+
+tp_size: 2 # Tensor Parallel sharding factor
+cp_size: 1
+
+use_sequence_packing: false
+
+config_kwargs:
+ attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
+ self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
+ tensor_parallel: true # Tensor Parallelism for TE
+ sequence_parallel: false # Sequence parallelism for LayerNorm on TP ranks.
+ tp_size: ${tp_size} # Tensor Parallel Size
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml
new file mode 100644
index 0000000000..b31d8c9500
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml
@@ -0,0 +1,27 @@
+defaults:
+ - L0_sanity
+ - _self_
+
+tp_size: 2 # TP Sharding
+cp_size: 2 # FSDP-CP Sharding
+
+dataset:
+ # CP2 * (8 for FP8 Activations, 16 for FP8 Parameters)
+ pad_sequences_to_be_divisible_by: 32
+
+fp8_config:
+ enabled: true
+ fp8_recipe: transformer_engine.common.recipe.DelayedScaling
+ fp8_format: "HYBRID"
+ fp8_recipe_kwargs: {}
+
+checkpoint:
+ ckpt_dir: ./fsdp_nd_ckpts
+ save_final_model: true
+
+config_kwargs:
+ attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
+ self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
+ tensor_parallel: true # Tensor Parallelism for TE
+ sequence_parallel: true # Sequence parallelism for LayerNorm on TP ranks.
+ tp_size: ${tp_size} # Tensor Parallel Size
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index 9302a0758d..fdd4f0f0b8 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -12,6 +12,10 @@ use_meta_device: true
# We leave this off by default since we don't see much of a performance improvement with TE layers.
use_torch_compile: false
+# Default parallelism sizes.
+tp_size: 1
+cp_size: 1
+
use_sequence_packing: false
dataset:
diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
index 713a26b8be..0a7cde0037 100644
--- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
+++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
@@ -31,6 +31,8 @@
import transformer_engine.common.recipe
import transformer_engine.pytorch
import transformers
+from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+from torch.distributed.tensor.placement_types import Replicate
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ tensor_parallel: bool = False
+ sequence_parallel: bool = False
+ tp_size: int = 1
def __init__(
self,
@@ -148,6 +153,8 @@ def __init__(
config: LlamaConfig,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
+ nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlama model.
@@ -155,6 +162,8 @@ def __init__(
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
+ nvte_tp_mesh: TP DeviceMesh for the model.
+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
self.config = config
@@ -162,6 +171,8 @@ def __init__(
self.vocab_size = config.vocab_size
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
+ self.tp_mesh = nvte_tp_mesh
+ self.weight_mesh = nvte_weight_mesh
if self.config.layer_precision is None:
if fp8_recipe is not None and fp4_recipe is not None:
@@ -180,6 +191,27 @@ def __init__(
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
+ # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
+ if config.tensor_parallel:
+ assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh."
+ assert self.tp_mesh.size() == config.tp_size, (
+ f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
+ f"does not match configured TP size ({config.tp_size}).",
+ )
+ # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
+ # during HuggingFace post-init, this will automatically convert the TELinear head
+ # weight into a DTensor with the correct sharding placements prior to FSDP2
+ # fully_shard(), and no need to call TELinear.set_device_mesh().
+ parallelize_module(
+ self.embed_tokens,
+ self.tp_mesh,
+ # Un-sharded output activations for compatible input to TETransformer.
+ # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
+ # RowwiseParallel doesn't support output_layouts=Replicate() with
+ # torch.compile: https://github.com/pytorch/torchtitan/issues/534
+ ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
+ )
+
def _init_method(x):
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
@@ -207,6 +239,11 @@ def _init_method(x):
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=_init_method,
output_layer_init_method=_init_method,
+ set_parallel_mode=config.tensor_parallel,
+ sequence_parallel=config.sequence_parallel,
+ tp_size=config.tp_size,
+ tp_mesh=self.tp_mesh,
+ weight_mesh=self.weight_mesh,
)
]
@@ -217,6 +254,8 @@ def _init_method(x):
dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
)
+        # Norm modules do not subclass the TE base module, so TP needs this manual set_device_mesh() call.
+        self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh)
# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
# LlamaRotaryEmbedding.
@@ -393,6 +432,8 @@ def __init__(
config,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
+ nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlamaForCausalLM model.
@@ -400,10 +441,21 @@ def __init__(
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
+ nvte_tp_mesh: TP DeviceMesh for the model.
+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
- self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ self.model = NVLlamaModel(
+ config,
+ fp8_recipe=fp8_recipe,
+ fp4_recipe=fp4_recipe,
+ nvte_tp_mesh=nvte_tp_mesh,
+ nvte_weight_mesh=nvte_weight_mesh,
+ )
+ self.config = config
self.vocab_size = config.vocab_size
+ self.tp_mesh = nvte_tp_mesh
+ self.weight_mesh = nvte_weight_mesh
with transformer_engine.pytorch.quantized_model_init(enabled=False):
self.lm_head = transformer_engine.pytorch.Linear(
config.hidden_size,
@@ -412,9 +464,25 @@ def __init__(
params_dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
+ parallel_mode="row" if config.tensor_parallel else None,
+            # sequence_parallel would scatter the output, which is never
+            # needed for the final layer; the output is all-reduced instead.
+ sequence_parallel=False,
+ tp_size=config.tp_size,
)
-
- # Initialize weights and apply final processing
+ if config.tensor_parallel:
+ if config.tie_word_embeddings:
+ # Head weights have already been tied to the embedding weights.
+ # Just set the tensor parallel group for TE.
+ # No parameter quantization either, so no need for weight_mesh.
+ self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
+ else:
+ # Head weights are not tied to the embedding weights. Need to
+ # wrap the LM head weight as a DTensor with TE.
+ # No parameter quantization either, so no need for weight_mesh.
+ self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh)
+
+ # Initialize weights and apply final processing. Ties weights.
self.post_init()
def forward(
@@ -467,6 +535,14 @@ def forward(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ if self.config.tensor_parallel:
+            # Under TP, slice the hidden-state activation along the hidden
+            # dimension to feed the row-wise tensor-parallel LM head.
+ # Use ... to support both BSHD (3D) and THD (2D) hidden states.
+ tp_rank = self.tp_mesh.get_local_rank()
+ tp_stride = hidden_states.shape[-1] // self.config.tp_size
+ hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
+
with transformer_engine.pytorch.autocast(enabled=False):
if hidden_states.ndim == 3:
logits = self.lm_head(hidden_states[:, slice_indices, :])
diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
index 073d9b39e3..ad7e7fb061 100644
--- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt
+++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
@@ -5,7 +5,7 @@ torchao!=0.14.0
torchdata
torchmetrics
tqdm
-transformer_engine[pytorch]
+transformer_engine[pytorch] @ git+https://github.com/cspades/TransformerEngine.git@7e0d3a9b9cca4243bcf9f233f79fcb0c09795e3f
transformers
wandb
zstandard
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
index 0cbaa64673..dda13749a1 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
@@ -706,7 +706,7 @@ def test_cp_dataloader(tokenizer_path):
cp_mesh = device_mesh["cp"]
- # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py
+ # Create the context-parallel dataloader directly following the pattern in train_fsdp2_nd_parallel.py
if cp_mesh.get_local_rank() == 0:
train_dataloader, _ = create_thd_dataloader(
distributed_config=dist_config,
@@ -799,7 +799,7 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path):
cp_mesh = device_mesh["cp"]
- # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py
+ # Create the context-parallel dataloader directly following the pattern in train_fsdp2_nd_parallel.py
if cp_mesh.get_local_rank() == 0:
train_dataloader, _ = create_thd_dataloader(
distributed_config=dist_config,
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
index 889f57345b..8f8b22f5f5 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
@@ -40,7 +40,7 @@
from train_ddp import main as main_ddp
from train_fsdp2 import main as main_fsdp2
-from train_fsdp2_cp import main as main_fsdp2_cp
+from train_fsdp2_nd_parallel import main as main_fsdp2_cp
os.environ["WANDB_DISABLED"] = "true"
@@ -329,7 +329,7 @@ def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(r
_run_multi_process_checkpoint_test(
recipe_path,
tmp_path,
- "train_fsdp2_cp.py",
+ "train_fsdp2_nd_parallel.py",
ckpt_subdir_name="train_fsdp2",
extra_overrides=["checkpoint.async_save=false", "cp_size=2"],
is_ddp=False,
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
index 89e85068de..5a3e1994d0 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
@@ -24,7 +24,7 @@
from train_ddp import main as main_ddp
from train_fsdp2 import main as main_fsdp2
-from train_fsdp2_cp import main as main_fsdp2_cp
+from train_fsdp2_nd_parallel import main as main_fsdp2_cp
# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed.
@@ -480,6 +480,61 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path):
assert stats_log.stat().st_size > 0, "Statistics log file is empty"
+def test_sanity_nd_parallel_tp1_bshd(tmp_path, recipe_path):
+ """Test ND-parallel training with tensor_parallel=True and tp_size=1 (trivial TP group), BSHD.
+
+ This test validates that all TP code paths in NVLlamaModel and NVLlamaForCausalLM execute
+ correctly with a single-rank TP mesh:
+ - parallelize_module on embed_tokens (ColwiseParallel)
+ - TransformerLayer TP mode flags
+ - lm_head row-parallel mode and set_tensor_parallel_group
+ - Hidden-state activation slicing in NVLlamaForCausalLM.forward
+ """
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity_tp",
+ overrides=[
+ f"+wandb.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "num_train_steps=10",
+ "tp_size=1",
+ "checkpoint.resume_from_checkpoint=false",
+ ],
+ )
+
+ final_loss = main_fsdp2_cp(sanity_config)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite"
+
+
+def test_sanity_nd_parallel_tp1_sequence_parallel_bshd(tmp_path, recipe_path):
+ """Test ND-parallel training with tensor_parallel=True, sequence_parallel=True, tp_size=1, BSHD.
+
+ Validates that the sequence-parallel RMSNorm (set_device_mesh on the final norm) does not
+ break forward/backward even when the TP group is a single rank.
+ """
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity_tp",
+ overrides=[
+ f"+wandb.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "num_train_steps=10",
+ "tp_size=1",
+ "config_kwargs.sequence_parallel=true",
+ "checkpoint.resume_from_checkpoint=false",
+ ],
+ )
+
+ final_loss = main_fsdp2_cp(sanity_config)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite"
+
+
@requires_fp8
def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
"""Test that FP8 stats logging works with FSDP2."""
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
index 0993f5fb46..cac6b11e13 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
@@ -199,7 +199,7 @@ def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path):
"torchrun",
"--standalone",
"--nproc_per_node=2",
- "train_fsdp2_cp.py",
+ "train_fsdp2_nd_parallel.py",
"--config-name",
"L0_sanity_cp",
"num_train_steps=10",
@@ -223,7 +223,7 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path):
"torchrun",
"--standalone",
"--nproc_per_node=2",
- "train_fsdp2_cp.py",
+ "train_fsdp2_nd_parallel.py",
"--config-name",
"L0_sanity_cp",
"num_train_steps=10",
@@ -238,6 +238,111 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path):
)
+@requires_multi_gpu
+def test_multi_gpu_train_te_fsdp2_tp_bshd(tmp_path, recipe_path):
+ """Test FSDP2 with tensor parallelism on 2 GPUs using BSHD input format.
+
+ Validates:
+ - The 3-D device mesh (dp=1, cp=1, tp=2) with a 2-rank TP dimension is created and used correctly
+ - Embedding weights are ColwiseParallel-sharded across 2 TP ranks
+ - TransformerLayer TP mode shards QKV/FFN weights across ranks
+ - Row-wise parallel LM head with hidden-state slicing before forward
+ """
+ run_train_cmd(
+ [
+ "torchrun",
+ "--standalone",
+ "--nproc_per_node=2",
+ "train_fsdp2_nd_parallel.py",
+ "--config-name",
+ "L0_sanity_tp",
+ "num_train_steps=10",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ ],
+ recipe_path,
+ )
+
+
+@requires_multi_gpu
+@requires_datacenter_hardware
+def test_multi_gpu_train_te_fsdp2_tp_thd(tmp_path, recipe_path):
+ """Test FSDP2 with tensor parallelism on 2 GPUs using THD (sequence-packed) input format.
+
+ Validates:
+ - TP=2, CP=1 with sequence-packing / THD attention format
+ - _unpad_input / _pad_input round-trip works alongside TP activation sharding
+ - padding_causal mask type is compatible with row-wise parallel LM head
+ """
+ run_train_cmd(
+ [
+ "torchrun",
+ "--standalone",
+ "--nproc_per_node=2",
+ "train_fsdp2_nd_parallel.py",
+ "--config-name",
+ "L0_sanity_tp",
+ "num_train_steps=10",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "use_sequence_packing=true",
+ "config_kwargs.attn_input_format=thd",
+ "config_kwargs.self_attn_mask_type=padding_causal",
+ ],
+ recipe_path,
+ )
+
+
+@requires_multi_gpu
+def test_multi_gpu_train_te_fsdp2_tp_sequence_parallel_bshd(tmp_path, recipe_path):
+ """Test FSDP2 with tensor parallelism + sequence parallelism on 2 GPUs, BSHD.
+
+ Validates that sequence parallelism (LayerNorm activations sharded across TP ranks)
+ works alongside standard tensor parallelism without errors.
+ """
+ run_train_cmd(
+ [
+ "torchrun",
+ "--standalone",
+ "--nproc_per_node=2",
+ "train_fsdp2_nd_parallel.py",
+ "--config-name",
+ "L0_sanity_tp",
+ "num_train_steps=10",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "config_kwargs.sequence_parallel=true",
+ ],
+ recipe_path,
+ )
+
+
+@requires_multi_gpu
+def test_multi_gpu_train_te_fsdp2_tp_bshd_with_checkpointing(tmp_path, recipe_path):
+ """Test FSDP2 TP training on 2 GPUs with checkpoint saving.
+
+ Validates:
+ - Sharded FSDP2 checkpoints are written correctly while TP is active
+ - The expected checkpoint directory structure is present after training
+ """
+ run_train_cmd(
+ [
+ "torchrun",
+ "--standalone",
+ "--nproc_per_node=2",
+ "train_fsdp2_nd_parallel.py",
+ "--config-name",
+ "L0_sanity_tp",
+ "num_train_steps=10",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.save_every_n_steps=5",
+ "checkpoint.resume_from_checkpoint=false",
+ ],
+ recipe_path,
+ )
+
+ ckpt_dir = tmp_path / "train_fsdp2"
+ assert ckpt_dir.exists(), f"Checkpoint directory not created: {ckpt_dir}"
+ assert (ckpt_dir / "step_5").exists(), "Checkpoint at step 5 not found"
+
+
nsys_available = subprocess.run(["which", "nsys"], check=False, capture_output=True).returncode == 0
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index da19daa2a7..ade90cef1b 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -19,7 +19,7 @@
the memory of a single GPU. Supports both TE-accelerated (NVLlamaForCausalLM) and standard
HuggingFace (LlamaForCausalLM) models.
-For very long sequences, use ``train_fsdp2_cp.py`` which adds Context Parallelism on top of FSDP2.
+For very long sequences, use ``train_fsdp2_nd_parallel.py``, which adds Tensor and Context Parallelism on top of FSDP2.
"""
import gc
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py
similarity index 81%
rename from bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
rename to bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py
index eaf1a1b39f..23cebe6318 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py
@@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""FSDP2 with Context Parallelism training script for Llama 3 with TransformerEngine.
+"""FSDP2 with Tensor & Context Parallelism training script for Llama 3 with TransformerEngine.
-Combines Fully Sharded Data Parallel v2 with Context Parallelism (CP), where each sequence is
-split across multiple GPUs along the sequence dimension. This is useful for training with very long
-sequences that do not fit into a single GPU's memory even with FSDP2 alone. Only supports
-TE-accelerated models (NVLlamaForCausalLM).
+Combines Fully Sharded Data Parallel v2 with Tensor Parallelism (TP) and Context Parallelism (CP).
+In Context Parallelism, each sequence is split across multiple GPUs along the sequence dimension,
+which is useful for training on extremely long sequences that exhaust activation memory.
+In Tensor Parallelism, weights and activations are sharded along the hidden dimension across multiple GPUs,
+which reduces activation memory in addition to weight memory — unlike FSDP, which shards only the weights.
+Only supports TE-accelerated models (NVLlamaForCausalLM).
-For standard FSDP2 training without context parallelism, use ``train_fsdp2.py`` instead.
+For standard FSDP2 training without N-D parallelism, use ``train_fsdp2.py`` instead.
"""
import gc
@@ -59,7 +61,7 @@
@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
def main(args: DictConfig) -> float | None:
- """Train Llama3 with TE layers using FSDP2 with Context Parallelism.
+ """Train Llama3 with TE layers using FSDP2, CP, and TP.
Returns:
float: The loss value for the final batch.
@@ -73,8 +75,8 @@ def main(args: DictConfig) -> float | None:
device_mesh = init_device_mesh(
"cuda",
- mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size),
- mesh_dim_names=("dp", "cp"),
+ mesh_shape=(dist_config.world_size // (args.cp_size * args.tp_size), args.cp_size, args.tp_size),
+ mesh_dim_names=("dp", "cp", "tp"),
)
logger.info("Created device mesh: %s", device_mesh)
@@ -94,11 +96,22 @@ def main(args: DictConfig) -> float | None:
config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
with torch.device("meta") if args.use_meta_device else nullcontext():
- model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ model = NVLlamaForCausalLM(
+ config,
+ fp8_recipe=fp8_recipe,
+ fp4_recipe=fp4_recipe,
+ # Only pass DeviceMesh to TransformerEngine if using Tensor Parallelism
+ # or if your DeviceMesh has multiple weight-sharding dimensions.
+ nvte_tp_mesh=device_mesh["tp"] if config.tensor_parallel else None,
+ # nvte_weight_mesh is only required for Float8CurrentScaling parameters.
+ nvte_weight_mesh=device_mesh["dp", "cp", "tp"]._flatten("weight_mesh") if config.tensor_parallel else None,
+ )
logger.info("Initialized Model:\n%s", model)
# --- Distributed Wrapping (FSDP2 + CP) ---
+
+ # Create a flattened mesh for FSDP2-CP sharding. This will shard the model across both the DP and CP ranks.
cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
@@ -107,7 +120,7 @@ def main(args: DictConfig) -> float | None:
fully_shard(layer, mesh=cp_dp_mesh)
fully_shard(model, mesh=cp_dp_mesh)
- # Attach the CP group to the model.
+ # Attach the CP ProcessGroup to the TransformerEngine model.
for layer in model.model.layers:
layer.set_context_parallel_group(
device_mesh["cp"].get_group(),
@@ -136,9 +149,12 @@ def main(args: DictConfig) -> float | None:
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)
- # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP)
- # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
- if device_mesh["cp"].get_local_rank() == 0:
+ # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and TP) ranks.
+ # This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
+ cp_tp_mesh = device_mesh["cp", "tp"]._flatten(mesh_dim_name="cp_tp")
+ if cp_tp_mesh.get_local_rank() == 0:
+ # We only create the dataloader on CP-TP Rank 0 and pass it to a ContextParallelDataLoaderWrapper
+ # that will shard, replicate, and distribute the data across the flattened CP and TP group.
if args.use_sequence_packing:
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
else:
@@ -155,8 +171,8 @@ def main(args: DictConfig) -> float | None:
train_dataloader = None
dataset_or_sampler = None
- # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0.
- train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])
+ # Broadcast data from CP-TP rank 0 and deliver CP-sharded replicas across the flattened CP-TP mesh.
+ train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_tp_mesh)
# --- Checkpoint Resume ---
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
@@ -169,7 +185,6 @@ def main(args: DictConfig) -> float | None:
ckpt_path=ckpt_path,
dist_config=dist_config,
dataloader=train_dataloader,
- process_group=cp_dp_mesh.get_group(),
)
logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
else:
@@ -234,7 +249,6 @@ def main(args: DictConfig) -> float | None:
epoch=epoch,
dist_config=dist_config,
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
- process_group=cp_dp_mesh.get_group(),
max_checkpoints=args.checkpoint.max_checkpoints,
async_save=args.checkpoint.async_save,
)
diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py
index 713a26b8be..0a7cde0037 100644
--- a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py
+++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py
@@ -31,6 +31,8 @@
import transformer_engine.common.recipe
import transformer_engine.pytorch
import transformers
+from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+from torch.distributed.tensor.placement_types import Replicate
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ tensor_parallel: bool = False
+ sequence_parallel: bool = False
+ tp_size: int = 1
def __init__(
self,
@@ -148,6 +153,8 @@ def __init__(
config: LlamaConfig,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
+ nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlama model.
@@ -155,6 +162,8 @@ def __init__(
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
+ nvte_tp_mesh: TP DeviceMesh for the model.
+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
self.config = config
@@ -162,6 +171,8 @@ def __init__(
self.vocab_size = config.vocab_size
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
+ self.tp_mesh = nvte_tp_mesh
+ self.weight_mesh = nvte_weight_mesh
if self.config.layer_precision is None:
if fp8_recipe is not None and fp4_recipe is not None:
@@ -180,6 +191,27 @@ def __init__(
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
+ # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
+ if config.tensor_parallel:
+        assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a nvte_tp_mesh DeviceMesh."
+ assert self.tp_mesh.size() == config.tp_size, (
+ f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
+            f"does not match configured TP size ({config.tp_size})."
+ )
+ # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
+ # during HuggingFace post-init, this will automatically convert the TELinear head
+ # weight into a DTensor with the correct sharding placements prior to FSDP2
+ # fully_shard(), and no need to call TELinear.set_device_mesh().
+ parallelize_module(
+ self.embed_tokens,
+ self.tp_mesh,
+ # Un-sharded output activations for compatible input to TETransformer.
+ # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
+ # RowwiseParallel doesn't support output_layouts=Replicate() with
+ # torch.compile: https://github.com/pytorch/torchtitan/issues/534
+ ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
+ )
+
def _init_method(x):
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
@@ -207,6 +239,11 @@ def _init_method(x):
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=_init_method,
output_layer_init_method=_init_method,
+ set_parallel_mode=config.tensor_parallel,
+ sequence_parallel=config.sequence_parallel,
+ tp_size=config.tp_size,
+ tp_mesh=self.tp_mesh,
+ weight_mesh=self.weight_mesh,
)
]
@@ -217,6 +254,8 @@ def _init_method(x):
dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
)
+        # Norm modules are non-Base TransformerEngine modules that require a manual set_device_mesh() call for TP.
+ self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh)
# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
# LlamaRotaryEmbedding.
@@ -393,6 +432,8 @@ def __init__(
config,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
+ nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlamaForCausalLM model.
@@ -400,10 +441,21 @@ def __init__(
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
+ nvte_tp_mesh: TP DeviceMesh for the model.
+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
- self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ self.model = NVLlamaModel(
+ config,
+ fp8_recipe=fp8_recipe,
+ fp4_recipe=fp4_recipe,
+ nvte_tp_mesh=nvte_tp_mesh,
+ nvte_weight_mesh=nvte_weight_mesh,
+ )
+ self.config = config
self.vocab_size = config.vocab_size
+ self.tp_mesh = nvte_tp_mesh
+ self.weight_mesh = nvte_weight_mesh
with transformer_engine.pytorch.quantized_model_init(enabled=False):
self.lm_head = transformer_engine.pytorch.Linear(
config.hidden_size,
@@ -412,9 +464,25 @@ def __init__(
params_dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
+ parallel_mode="row" if config.tensor_parallel else None,
+            # sequence_parallel would scatter the output, which is never needed for the final layer;
+            # the output is all-reduced instead, as required.
+ sequence_parallel=False,
+ tp_size=config.tp_size,
)
-
- # Initialize weights and apply final processing
+ if config.tensor_parallel:
+ if config.tie_word_embeddings:
+ # Head weights have already been tied to the embedding weights.
+ # Just set the tensor parallel group for TE.
+ # No parameter quantization either, so no need for weight_mesh.
+ self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
+ else:
+ # Head weights are not tied to the embedding weights. Need to
+ # wrap the LM head weight as a DTensor with TE.
+ # No parameter quantization either, so no need for weight_mesh.
+ self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh)
+
+ # Initialize weights and apply final processing. Ties weights.
self.post_init()
def forward(
@@ -467,6 +535,14 @@ def forward(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ if self.config.tensor_parallel:
+ # If using TP, shard your activation across the TP group,
+ # to support row-wise tensor parallelism in the LM head.
+ # Use ... to support both BSHD (3D) and THD (2D) hidden states.
+ tp_rank = self.tp_mesh.get_local_rank()
+ tp_stride = hidden_states.shape[-1] // self.config.tp_size
+ hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
+
with transformer_engine.pytorch.autocast(enabled=False):
if hidden_states.ndim == 3:
logits = self.lm_head(hidden_states[:, slice_indices, :])
diff --git a/bionemo-recipes/recipes/vit/train.py b/bionemo-recipes/recipes/vit/train.py
index db7c410638..5df1348523 100644
--- a/bionemo-recipes/recipes/vit/train.py
+++ b/bionemo-recipes/recipes/vit/train.py
@@ -22,7 +22,7 @@
import torch
import wandb
from hydra.core.hydra_config import HydraConfig
-from megatron_fsdp import fully_shard
+from megatron_fsdp.fully_shard import fully_shard
from omegaconf import OmegaConf
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
diff --git a/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml b/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml
index 82b9a2df3d..8178316790 100644
--- a/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml
+++ b/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml
@@ -87,7 +87,7 @@ products:
job_name: "llama3-native-L0-sanity"
# L0 sanity test with context parallelism
- config: L0_sanity_cp
- task_cmd: train_fsdp2_cp
+ task_cmd: train_fsdp2_nd_parallel
cp_enabled: true
wandb_name: "llama3_native__L0_sanity_cp__${now:%Y%m%d-%H%M%S}__${gitsha:}"
job_name: "llama3-native-L0-sanity-cp"