Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bionemo-recipes/models/llama3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The Llama-3 implementation natively supports the following TransformerEngine-pro
| **Export to HuggingFace checkpoints** | ✅ Supported |
| **KV-cache inference** | ✅ Supported (including beam search) |
| **Context Parallelism** | ✅ Supported |
| **Tensor Parallelism** | 🚧 Under development |
| **Tensor Parallelism** | ✅ Supported |

Refer to [BioNeMo Recipes](../../recipes/llama3_native_te/README.md) for more details on how to use these features to accelerate model
training and inference with native PyTorch training loops.
Expand Down
82 changes: 79 additions & 3 deletions bionemo-recipes/models/llama3/modeling_llama_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import transformer_engine.common.recipe
import transformer_engine.pytorch
import transformers
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.distributed.tensor.placement_types import Replicate
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
Expand Down Expand Up @@ -52,6 +54,9 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
tensor_parallel: bool = False
sequence_parallel: bool = False
tp_size: int = 1

def __init__(
self,
Expand Down Expand Up @@ -142,20 +147,26 @@ def __init__(
config: LlamaConfig,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlama model.

Args:
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
nvte_tp_mesh: TP DeviceMesh for the model.
nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
self.tp_mesh = nvte_tp_mesh
self.weight_mesh = nvte_weight_mesh

if self.config.layer_precision is None:
if fp8_recipe is not None and fp4_recipe is not None:
Expand All @@ -174,6 +185,27 @@ def __init__(

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)

# Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
if config.tensor_parallel:
assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh."
assert self.tp_mesh.size() == config.tp_size, (
f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
f"does not match configured TP size ({config.tp_size}).",
)
# NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
# during HuggingFace post-init, this will automatically convert the TELinear head
# weight into a DTensor with the correct sharding placements prior to FSDP2
# fully_shard(), and no need to call TELinear.set_device_mesh().
parallelize_module(
self.embed_tokens,
self.tp_mesh,
# Un-sharded output activations for compatible input to TETransformer.
# NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
# RowwiseParallel doesn't support output_layouts=Replicate() with
# torch.compile: https://github.com/pytorch/torchtitan/issues/534
ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
)

def _init_method(x):
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)

Expand Down Expand Up @@ -201,6 +233,11 @@ def _init_method(x):
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=_init_method,
output_layer_init_method=_init_method,
set_parallel_mode=config.tensor_parallel,
sequence_parallel=config.sequence_parallel,
tp_size=config.tp_size,
tp_mesh=self.tp_mesh,
weight_mesh=self.weight_mesh,
)
]

Expand All @@ -211,6 +248,8 @@ def _init_method(x):
dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
)
# Norm modules are non-Base TransformerEngine modules that require a manual call for TP.
self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh)

# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
# LlamaRotaryEmbedding.
Expand Down Expand Up @@ -387,17 +426,30 @@ def __init__(
config,
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
):
"""Initialize the NVLlamaForCausalLM model.

Args:
config: The configuration of the model.
fp8_recipe: The FP8 recipe for the model.
fp4_recipe: The FP4 recipe for the model.
nvte_tp_mesh: TP DeviceMesh for the model.
nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
"""
super().__init__(config)
self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
self.model = NVLlamaModel(
config,
fp8_recipe=fp8_recipe,
fp4_recipe=fp4_recipe,
nvte_tp_mesh=nvte_tp_mesh,
nvte_weight_mesh=nvte_weight_mesh,
)
self.config = config
self.vocab_size = config.vocab_size
self.tp_mesh = nvte_tp_mesh
self.weight_mesh = nvte_weight_mesh
with transformer_engine.pytorch.quantized_model_init(enabled=False):
self.lm_head = transformer_engine.pytorch.Linear(
config.hidden_size,
Expand All @@ -406,9 +458,25 @@ def __init__(
params_dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
parallel_mode="row" if config.tensor_parallel else None,
# This scatters your output, not ever needed for final layer.
# Will all-reduce the output instead, as required.
sequence_parallel=False,
tp_size=config.tp_size,
)

# Initialize weights and apply final processing
if config.tensor_parallel:
if config.tie_word_embeddings:
# Head weights have already been tied to the embedding weights.
# Just set the tensor parallel group for TE.
# No parameter quantization either, so no need for weight_mesh.
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
else:
# Head weights are not tied to the embedding weights. Need to
# wrap the LM head weight as a DTensor with TE.
# No parameter quantization either, so no need for weight_mesh.
self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh)

# Initialize weights and apply final processing. Ties weights.
self.post_init()

def forward(
Expand Down Expand Up @@ -461,6 +529,14 @@ def forward(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

if self.config.tensor_parallel:
# If using TP, shard your activation across the TP group,
# to support row-wise tensor parallelism in the LM head.
# Use ... to support both BSHD (3D) and THD (2D) hidden states.
tp_rank = self.tp_mesh.get_local_rank()
tp_stride = hidden_states.shape[-1] // self.config.tp_size
hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]

with transformer_engine.pytorch.autocast(enabled=False):
if hidden_states.ndim == 3:
logits = self.lm_head(hidden_states[:, slice_indices, :])
Expand Down
2 changes: 1 addition & 1 deletion bionemo-recipes/models/llama3/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
lm-eval # For testing
torch
torchao!=0.14.0
transformer_engine[pytorch]
transformer_engine[pytorch] @ git+https://github.com/cspades/TransformerEngine.git@7e0d3a9b9cca4243bcf9f233f79fcb0c09795e3f
transformers
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path):
[
"torchrun",
"--nproc_per_node=2",
"train_fsdp2_cp.py",
"train_fsdp2_nd_parallel.py",
"--config-name",
"L0_sanity_cp",
"num_train_steps=4",
Expand Down
2 changes: 1 addition & 1 deletion bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(args: DictConfig) -> float | None:
)

if args.use_fp32_master_weights:
raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_cp.py instead.")
raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_nd_parallel.py instead.")

# Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
# Note: token_dropout is set to False because it's not compatible with context parallelism.
Expand Down
25 changes: 21 additions & 4 deletions bionemo-recipes/recipes/llama3_native_te/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ bionemo-framework repository. You can download a zipped directory of this folder

| Model | BF16 | FP8<sup>[1]</sup> | THD Input Format | FP8 with THD Input Format | MXFP8<sup>[2]</sup> | Context Parallelism | Tensor Parallelism |
| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ |
| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |

✅: Supported <br/>
🚧: Under development <br/>
Expand Down Expand Up @@ -94,7 +94,7 @@ This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Conte

- [Distributed Data Parallel (DDP)](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), shown in `train_ddp.py`
- [Fully Sharded Data Parallel 2 (FSDP2)](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html), shown in `train_fsdp2.py`
- FSDP2 with Context Parallelism, shown in `train_fsdp2_cp.py`
- FSDP2 with Context Parallelism, shown in `train_fsdp2_nd_parallel.py`

## Commands to Launch Training

Expand Down Expand Up @@ -193,13 +193,30 @@ python train_fsdp2.py --config-name L0_sanity \
### Context Parallel Training

Context parallelism splits each sequence across multiple GPUs along the sequence dimension, enabling training with very
long sequences. Use `train_fsdp2_cp.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context
long sequences. Use `train_fsdp2_nd_parallel.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context
parallelism ranks. Works with both BSHD (no padding) and THD (padding) input formats. Only TE models are supported.

```bash
torchrun --nproc_per_node=4 train_fsdp2_cp.py --config-name L0_sanity_cp cp_size=2
torchrun --nproc_per_node=4 train_fsdp2_nd_parallel.py --config-name L0_sanity_cp cp_size=2
```

### Tensor Parallel Training

Tensor parallelism shards model activations and weights along the hidden dimension across multiple GPUs, and is compatible with
context parallelism and FSDP2 in TransformerEngine. Use `train_fsdp2_nd_parallel.py` with the `L2_sanity_nd` config and optionally
set `tp_size` and `cp_size` such that the product of these sizes is less than or equal to the number of GPUs. When the FSDP and
TP sharding dimensions coincide, FSDP2 will automatically create `_StridedShard` placements to ensure that TP shards are the
result of the all-gather during training. Only TE models are supported; refer to the `set_device_mesh` API on TE modules for
more information about how `DeviceMesh` and `DTensor` interact with TransformerEngine.

```bash
torchrun --nproc_per_node=4 train_fsdp2_nd_parallel.py --config-name L2_sanity_nd cp_size=2 tp_size=2
```

Note that TransformerEngine TP is compatible with DTensor-based TP. For instance, the `torch.nn.Embedding` layer is weight-tied to the
Llama LM head: the embedding layer is tensor-parallelized with `parallelize_module` (DTensor), while the head is tensor-parallelized with
TransformerEngine.

## Downloading Pre-Training Data For Offline Training

This recipe is configured to use genomic sequences. The default configuration uses a local test file
Expand Down
31 changes: 28 additions & 3 deletions bionemo-recipes/recipes/llama3_native_te/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,28 @@ class AppState(Stateful):
epoch: int = 0

def state_dict(self):
"""Get the state dict for the model, optimizer, scheduler, and step."""
"""Get the state dict for the model, optimizer, scheduler, and step.
This factory both retrieves the model state dictionary when saving
checkpoints and initializes a destination for the state read from
DCP checkpoint files when loading checkpoints.
"""
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")}
for fqn in list(model_state_dict.keys()):
# Get the model parameter.
model_param = model_state_dict[fqn]
if isinstance(model_param, DTensor):
model_param = model_param.to_local()
if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]:
# Empty model parameter. Clear the associated optimizer state
# when initializing the optimizer state upon DCP load, because
# empty optimizer state DTensors are not checkpointed with DCP,
# yet get_state_dict / _init_optim_state produce empty Tensors.
# TransformerEngine uses empty Tensors for dummy Parameters.
optimizer_state_dict["state"][fqn] = {}
if fqn.endswith("_extra_state"):
# Evict `_extra_state` quantization data from model checkpoint.
model_state_dict.pop(fqn)
return {
"model": model_state_dict,
"optim": optimizer_state_dict,
Expand All @@ -245,12 +264,18 @@ def state_dict(self):
}

def load_state_dict(self, state_dict: dict):
"""Load the state dict for the model, optimizer, scheduler, and step."""
"""Load the state dict for the model, optimizer, scheduler, and step.
Given the checkpoint-loaded state_dict, set the state of the model,
optimizer, scheduler, step, and epoch to the values in state_dict.
"""
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"],
# Non-strict checkpoint loading ignores empty optimizer states,
# skips loading non-FP8 checkpoint weights (e.g. _extra_state).
options=StateDictOptions(strict=False),
)
self.scheduler.load_state_dict(state_dict["scheduler"])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Sanity config for Tensor Parallel (TP) training: inherits L0_sanity and
# enables TransformerEngine tensor parallelism with tp_size=2 (no CP).
defaults:
  - L0_sanity
  - _self_

tp_size: 2 # Tensor Parallel sharding factor
cp_size: 1 # No context parallelism in this config.

use_sequence_packing: false # BSHD inputs below; sequence packing (THD) disabled.

config_kwargs:
  attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
  self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
  tensor_parallel: true # Tensor Parallelism for TE
  sequence_parallel: false # Sequence parallelism for LayerNorm on TP ranks.
  tp_size: ${tp_size} # Tensor Parallel Size (interpolated from the top-level key above).
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Sanity config for N-D parallelism: FSDP2 + Context Parallel (cp_size=2)
# + Tensor Parallel (tp_size=2), with sequence parallelism and FP8 enabled.
defaults:
  - L0_sanity
  - _self_

tp_size: 2 # TP Sharding
cp_size: 2 # FSDP-CP Sharding

dataset:
  # CP2 * (8 for FP8 Activations, 16 for FP8 Parameters):
  # 32 = cp_size (2) x 16, so each CP shard keeps FP8-friendly alignment.
  pad_sequences_to_be_divisible_by: 32

fp8_config:
  enabled: true
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_format: "HYBRID" # TE hybrid format (see TE recipe docs for fwd/bwd dtypes).
  fp8_recipe_kwargs: {}

checkpoint:
  ckpt_dir: ./fsdp_nd_ckpts # Local directory for distributed (DCP) checkpoints.
  save_final_model: true

config_kwargs:
  attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
  self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
  tensor_parallel: true # Tensor Parallelism for TE
  sequence_parallel: true # Sequence parallelism for LayerNorm on TP ranks.
  tp_size: ${tp_size} # Tensor Parallel Size (interpolated from the top-level key above).
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use_meta_device: true
# We leave this off by default since we don't see much of a performance improvement with TE layers.
use_torch_compile: false

# Default parallelism sizes.
tp_size: 1
cp_size: 1

use_sequence_packing: false

dataset:
Expand Down
Loading
Loading