1 change: 1 addition & 0 deletions bionemo-recipes/recipes/esm2_native_te/.ruff.toml
@@ -1 +1,2 @@
extend = "../.ruff.toml"
ignore = ["C901"]
@@ -1,6 +1,7 @@
# Training config
model_tag: ??? # E.g., nvidia/esm2_t6_8M_UR50D, facebook/esm2_t6_8M_UR50D, or a local path (e.g ./example_8m_checkpoint)
num_train_steps: ???
grad_acc_steps: 1

# TODO: Once BIONEMO-2583 and BIONEMO-2719 are fixed, enable this by default and simplify training scripts to remove the
# meta-device conditional.
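For context on the new grad_acc_steps option: the default of 1 preserves the previous one-micro-batch-per-optimizer-step behavior, while values of 2 or more accumulate gradients over that many micro-batches before each optimizer step. A minimal sketch of how the knob scales the effective per-rank batch size, assuming a hypothetical micro_batch_size field (the actual dataloader batch-size key is not shown in this diff):

```python
# Minimal sketch, not part of the PR: relates grad_acc_steps to the effective
# per-rank batch size. `micro_batch_size` is a hypothetical field for illustration.
from omegaconf import OmegaConf

args = OmegaConf.create({"grad_acc_steps": 4, "micro_batch_size": 8})

if args.grad_acc_steps < 1:
    raise ValueError(f"grad_acc_steps must be >= 1, got {args.grad_acc_steps}")

# One optimizer step consumes grad_acc_steps micro-batches, so each rank
# effectively trains on micro_batch_size * grad_acc_steps sequences per step.
effective_per_rank_batch = args.micro_batch_size * args.grad_acc_steps
print(effective_per_rank_batch)  # 32
```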
49 changes: 29 additions & 20 deletions bionemo-recipes/recipes/esm2_native_te/perf_logger.py
@@ -22,7 +22,6 @@
import wandb
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from transformers.modeling_outputs import MaskedLMOutput

from distributed_config import DistributedConfig

@@ -68,43 +67,47 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
# Log the entire args object to wandb for experiment tracking and reproducibility.
wandb.init(**args.wandb_init_args, config=self._run_config)
self._progress_bar = tqdm(total=args.num_train_steps, desc="Training")
self.num_tokens = 0
self.num_unpadded_tokens = 0
self.running_loss = 0
self.grad_acc_step_count = 0

def log_micro_step(self, batch, outputs):
"""Store data on micro step for accumulation metrics."""
self.grad_acc_step_count += 1
self.num_tokens += batch["input_ids"].numel()
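# 1 is the padding token for ESM-2.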
self.num_unpadded_tokens += batch["input_ids"][batch["input_ids"] != 1].numel()
self.running_loss += outputs.loss.item()
# Handle sequence packing for torchmetrics calculation.
if outputs.logits.dim() < 3:
outputs.logits = outputs.logits.unsqueeze(0)
self.metrics["train/perplexity"].update(outputs.logits, batch["labels"])

def log_step(
self,
step: int,
batch: dict[str, torch.Tensor],
outputs: MaskedLMOutput,
grad_norm: float,
lr: float,
):
"""Log a step to the logger and wandb.

Args:
step: The step number.
batch: The batch of data for the step.
outputs: The outputs of the step.
grad_norm: The gradient norm of the step.
lr: The learning rate of the step.
"""
num_tokens = batch["input_ids"].numel()
# 1 is the padding token for ESM-2.
num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel()

self.min_loss = min(self.min_loss, outputs.loss.item())
assert self.grad_acc_step_count > 0, (
f"grad_acc_step_count is {self.grad_acc_step_count}; call log_micro_step() for each micro-batch before calling log_step()."
)
self.min_loss = min(self.min_loss, self.running_loss / self.grad_acc_step_count)
step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter()

self.metrics["train/loss"].update(outputs.loss)
self.metrics["train/loss"].update(self.running_loss / self.grad_acc_step_count)
self.metrics["train/learning_rate"].update(lr)
self.metrics["train/grad_norm"].update(grad_norm)
self.metrics["train/step_time"].update(step_time)
self.metrics["train/tokens_per_second"].update(num_tokens / step_time)
self.metrics["train/unpadded_tokens_per_second"].update(num_unpadded_tokens / step_time)

# Handle sequence packing for torchmetrics calculation.
if outputs.logits.dim() < 3:
outputs.logits = outputs.logits.unsqueeze(0)

self.metrics["train/perplexity"].update(outputs.logits, batch["labels"])
self.metrics["train/tokens_per_second"].update(self.num_tokens / step_time)
self.metrics["train/unpadded_tokens_per_second"].update(self.num_unpadded_tokens / step_time)

if step % self.logging_frequency == 0 and step > 0:
metrics = self.metrics.compute()
@@ -114,11 +117,17 @@ def log_step(
if self._dist_config.is_main_process():
wandb.log(metrics, step=step)
self._progress_bar.update(self.logging_frequency)
self._progress_bar.set_postfix({"loss": outputs.loss.item()})
self._progress_bar.set_postfix({"loss": self.running_loss / self.grad_acc_step_count})

if self._dist_config.local_rank == 0:
logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()]))

# Reset counters.
self.num_tokens = 0
self.num_unpadded_tokens = 0
self.running_loss = 0
self.grad_acc_step_count = 0

def finish(self):
"""Finish the logger and close the progress bar."""
if not self._dist_config.is_main_process():
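With gradient accumulation, the logger's contract is: log_micro_step() runs once per micro-batch and accumulates the running loss, token counts, and perplexity inputs; log_step() runs once per optimizer step, reports the mean loss over the accumulation window, and then resets the counters. A sketch of the intended call pattern, reusing names from the training scripts below (model, optimizer, scheduler, perf_logger, and the micro-batch iterable are assumed to exist; this is not the recipe's actual loop):

```python
# Sketch of the intended PerfLogger call pattern (assumes model, optimizer,
# scheduler, perf_logger, and an iterable of micro-batch lists already exist).
import torch

for step, micro_batches in enumerate(micro_batch_windows):
    for micro_batch in micro_batches:
        outputs = model(**micro_batch)
        (outputs.loss / len(micro_batches)).backward()
        perf_logger.log_micro_step(micro_batch, outputs)  # once per micro-batch

    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # Once per optimizer step; this resets the logger's running counters.
    perf_logger.log_step(step=step, grad_norm=grad_norm, lr=optimizer.param_groups[0]["lr"])
```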
93 changes: 54 additions & 39 deletions bionemo-recipes/recipes/esm2_native_te/train_ddp.py
@@ -14,6 +14,7 @@
# limitations under the License.

import logging
from contextlib import nullcontext
from pathlib import Path

import hydra
@@ -37,12 +38,18 @@


@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
def main(args: DictConfig) -> float | None: # noqa: C901
def main(args: DictConfig) -> float | None:
"""Train ESM-2 with TE layers using ddp.

Returns:
float: The loss value for the final batch.
"""
# Validate arguments.
if args.grad_acc_steps < 1:
raise ValueError(
f"Gradient accumulation steps must be an integer greater than or equal to 1, but got: {args.grad_acc_steps}"
)

# Initialize the distributed configuration, including creating the distributed process group.
dist_config = DistributedConfig()
logger.info("Initializing distributed training: %s", dist_config)
@@ -115,49 +122,57 @@ def main(args: DictConfig) -> float | None: # noqa: C901

# Training loop
step = start_step
micro_step = 0
while step < args.num_train_steps:
for batch in train_dataloader:
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901

# Forward pass with mixed precision.
with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
outputs = model(**batch)

# Backward pass.
loss = outputs.loss
loss.backward()

# Compute and clip gradient norms.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

# Step optimizer.
optimizer.step()
scheduler.step()
optimizer.zero_grad()

perf_logger.log_step(
step=step,
batch=batch,
outputs=outputs,
grad_norm=total_norm,
lr=optimizer.param_groups[0]["lr"],
)

if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
save_checkpoint_ddp(
model=model,
optimizer=optimizer,
scheduler=scheduler,
ckpt_path=ckpt_path,
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901

micro_step += 1
with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext():
# Forward pass with mixed precision.
with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
outputs = model(**batch)

# Backward pass.
loss = outputs.loss / args.grad_acc_steps
loss.backward()

# Log microbatch step data for accumulation metrics.
perf_logger.log_micro_step(batch, outputs)

# Gradient accumulation.
if micro_step % args.grad_acc_steps == 0:
micro_step = 0

# Compute and clip gradient norms.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

# Step optimizer.
optimizer.step()
scheduler.step()
optimizer.zero_grad()

perf_logger.log_step(
step=step,
dist_config=dist_config,
dataloader=train_dataloader,
epoch=epoch,
grad_norm=total_norm,
lr=optimizer.param_groups[0]["lr"],
)

step += 1
if step >= args.num_train_steps:
break
if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
save_checkpoint_ddp(
model=model,
optimizer=optimizer,
scheduler=scheduler,
ckpt_path=ckpt_path,
step=step,
dist_config=dist_config,
dataloader=train_dataloader,
epoch=epoch,
)

step += 1
if step >= args.num_train_steps:
break

# Dataloader exhausted, incrementing epoch
epoch += 1
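In the DDP path above, model.no_sync() suppresses DDP's gradient all-reduce on every micro-batch except the last one in the accumulation window, and nullcontext() re-enables synchronization for that final backward, so cross-rank communication happens once per optimizer step instead of once per micro-batch. A minimal, self-contained sketch of the same pattern with a toy model and random data (hypothetical script name; not the recipe's ESM-2 setup):

```python
# Toy DDP gradient-accumulation sketch: run with
#   torchrun --nproc_per_node=2 ddp_grad_acc_sketch.py   (hypothetical file name)
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def main():
    dist.init_process_group("gloo")  # use "nccl" on GPUs
    model = DistributedDataParallel(torch.nn.Linear(16, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    grad_acc_steps = 4

    micro_step = 0
    for _ in range(8):  # 8 micro-batches -> 2 optimizer steps
        micro_step += 1
        x, y = torch.randn(4, 16), torch.randn(4, 1)

        # Skip the gradient all-reduce on all but the last micro-batch of the window.
        with model.no_sync() if micro_step % grad_acc_steps != 0 else nullcontext():
            loss = torch.nn.functional.mse_loss(model(x), y) / grad_acc_steps
            loss.backward()

        if micro_step % grad_acc_steps == 0:
            micro_step = 0
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```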
73 changes: 43 additions & 30 deletions bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
@@ -41,12 +41,18 @@


@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
def main(args: DictConfig) -> float | None: # noqa: C901
def main(args: DictConfig) -> float | None:
"""Train ESM-2 with TE layers using fsdp2.

Returns:
float: The loss value for the final batch.
"""
# Validate arguments.
if args.grad_acc_steps < 1:
raise ValueError(
f"Gradient accumulation steps must be an integer greater than or equal to 1, but got: {args.grad_acc_steps}"
)

# Initialize the distributed configuration, including creating the distributed process group.
dist_config = DistributedConfig()
logger.info("Initializing distributed training: %s", dist_config)
@@ -119,49 +125,56 @@ def main(args: DictConfig) -> float | None: # noqa: C901

# Training loop
step = start_step
micro_step = 0
while step < args.num_train_steps:
for batch in train_dataloader:
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901

micro_step += 1

# Forward pass with mixed precision.
with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
outputs = model(**batch)

# Backward pass.
loss = outputs.loss
loss = outputs.loss / args.grad_acc_steps
loss.backward()

# Compute and clip gradient norms.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

# Step optimizer.
optimizer.step()
scheduler.step()
optimizer.zero_grad()

perf_logger.log_step(
step=step,
batch=batch,
outputs=outputs,
grad_norm=total_norm,
lr=optimizer.param_groups[0]["lr"],
)

if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
save_checkpoint_fsdp2(
model=model,
optimizer=optimizer,
scheduler=scheduler,
ckpt_path=ckpt_path,
# Log microbatch step data for accumulation metrics.
perf_logger.log_micro_step(batch, outputs)

if micro_step % args.grad_acc_steps == 0:
micro_step = 0

# Compute and clip gradient norms.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

# Step optimizer.
optimizer.step()
scheduler.step()
optimizer.zero_grad()

perf_logger.log_step(
step=step,
epoch=epoch,
dist_config=dist_config,
dataloader=train_dataloader,
grad_norm=total_norm,
lr=optimizer.param_groups[0]["lr"],
)

step += 1
if step >= args.num_train_steps:
break
if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
save_checkpoint_fsdp2(
model=model,
optimizer=optimizer,
scheduler=scheduler,
ckpt_path=ckpt_path,
step=step,
epoch=epoch,
dist_config=dist_config,
dataloader=train_dataloader,
)

step += 1
if step >= args.num_train_steps:
break

# Dataloader exhausted, incrementing epoch
epoch += 1
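In both training scripts, each micro-batch loss is divided by grad_acc_steps before backward(), so the accumulated gradient matches what a single backward over the whole accumulation window would produce, provided the per-micro-batch loss is a mean and the micro-batches are equally sized. (For the masked-LM loss the number of masked tokens varies per micro-batch, so in practice the match is approximate rather than exact.) A small self-contained check of the equivalence, independent of the recipe:

```python
# Self-contained check (not part of the PR): accumulating gradients from N
# micro-batches with the loss scaled by 1/N matches one backward over the full
# batch when the loss is a mean and the micro-batches are equally sized.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
x, y = torch.randn(16, 8), torch.randn(16, 1)
grad_acc_steps = 4

# Gradient from a single full-batch backward pass.
model.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
full_batch_grad = model.weight.grad.clone()

# Gradient accumulated over 4 micro-batches of 4 samples each.
model.zero_grad()
for xb, yb in zip(x.chunk(grad_acc_steps), y.chunk(grad_acc_steps)):
    (torch.nn.functional.mse_loss(model(xb), yb) / grad_acc_steps).backward()

print(torch.allclose(full_batch_grad, model.weight.grad, atol=1e-6))  # True
```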