
Commit 2f6b3d2

Add gradient accumulation to ESM-2.
Signed-off-by: Cory Ye <[email protected]>
1 parent: 8ff2e4b

File tree: 5 files changed (+158, -119 lines)

bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 # Training config
 model_tag: ??? # E.g., nvidia/esm2_t6_8M_UR50D, facebook/esm2_t6_8M_UR50D, or a local path (e.g ./example_8m_checkpoint)
 num_train_steps: ???
+grad_acc_steps: 1

 # TODO: Once BIONEMO-2583 and BIONEMO-2719 are fixed, enable this by default and simplify training scripts to remove the
 # meta-device conditional.
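The new grad_acc_steps knob defaults to 1, i.e. no accumulation. As a rough illustration of what raising it does, the number of sequences contributing to each optimizer step scales linearly with the accumulation factor. The names below (micro_batch_size, world_size) are placeholders for this sketch only, not keys in defaults.yaml:

# Hedged sketch: effective batch per optimizer step under gradient accumulation.
# micro_batch_size and world_size are illustrative placeholders, not config keys.
micro_batch_size = 16   # sequences per rank per forward/backward pass
grad_acc_steps = 4      # micro-batches accumulated before each optimizer.step()
world_size = 8          # data-parallel ranks

sequences_per_optimizer_step = micro_batch_size * grad_acc_steps * world_size
print(sequences_per_optimizer_step)  # 512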

bionemo-recipes/recipes/esm2_native_te/perf_logger.py

Lines changed: 26 additions & 18 deletions
@@ -68,43 +68,45 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
         # Log the entire args object to wandb for experiment tracking and reproducibility.
         wandb.init(**args.wandb_init_args, config=self._run_config)
         self._progress_bar = tqdm(total=args.num_train_steps, desc="Training")
+        self.num_tokens = 0
+        self.num_unpadded_tokens = 0
+        self.running_loss = 0
+        self.grad_acc_steps = 0
+
+    def log_micro_step(self, batch, outputs):
+        """Store data on micro step for accumulation metrics."""
+        self.grad_acc_steps += 1
+        self.num_tokens += batch["input_ids"].numel()
+        self.num_unpadded_tokens += batch["input_ids"][batch["input_ids"] != 1].numel()
+        self.running_loss += outputs.loss.item()
+        # Handle sequence packing for torchmetrics calculation.
+        if outputs.logits.dim() < 3:
+            outputs.logits = outputs.logits.unsqueeze(0)
+        self.metrics["train/perplexity"].update(outputs.logits, batch["labels"])

     def log_step(
         self,
         step: int,
-        batch: dict[str, torch.Tensor],
-        outputs: MaskedLMOutput,
         grad_norm: float,
         lr: float,
     ):
         """Log a step to the logger and wandb.

         Args:
             step: The step number.
-            batch: The batch of data for the step.
-            outputs: The outputs of the step.
             grad_norm: The gradient norm of the step.
             lr: The learning rate of the step.
         """
-        num_tokens = batch["input_ids"].numel()
-        # 1 is the padding token for ESM-2.
-        num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel()

-        self.min_loss = min(self.min_loss, outputs.loss.item())
+        self.min_loss = min(self.min_loss, self.running_loss / self.grad_acc_steps)
         step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter()

-        self.metrics["train/loss"].update(outputs.loss)
+        self.metrics["train/loss"].update(self.running_loss / self.grad_acc_steps)
         self.metrics["train/learning_rate"].update(lr)
         self.metrics["train/grad_norm"].update(grad_norm)
         self.metrics["train/step_time"].update(step_time)
-        self.metrics["train/tokens_per_second"].update(num_tokens / step_time)
-        self.metrics["train/unpadded_tokens_per_second"].update(num_unpadded_tokens / step_time)
-
-        # Handle sequence packing for torchmetrics calculation.
-        if outputs.logits.dim() < 3:
-            outputs.logits = outputs.logits.unsqueeze(0)
-
-        self.metrics["train/perplexity"].update(outputs.logits, batch["labels"])
+        self.metrics["train/tokens_per_second"].update(self.num_tokens / step_time)
+        self.metrics["train/unpadded_tokens_per_second"].update(self.num_unpadded_tokens / step_time)

         if step % self.logging_frequency == 0 and step > 0:
             metrics = self.metrics.compute()
@@ -114,10 +116,16 @@ def log_step(
             if self._dist_config.is_main_process():
                 wandb.log(metrics, step=step)
                 self._progress_bar.update(self.logging_frequency)
-                self._progress_bar.set_postfix({"loss": outputs.loss.item()})
+                self._progress_bar.set_postfix({"loss": self.running_loss / self.grad_acc_steps})

             if self._dist_config.local_rank == 0:
                 logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()]))
+
+        # Reset counters.
+        self.num_tokens = 0
+        self.num_unpadded_tokens = 0
+        self.running_loss = 0
+        self.grad_acc_steps = 0

     def finish(self):
         """Finish the logger and close the progress bar."""

bionemo-recipes/recipes/esm2_native_te/train_ddp.py

Lines changed: 46 additions & 36 deletions
@@ -18,6 +18,7 @@

 import hydra
 import torch
+from contextlib import nullcontext
 import transformer_engine.pytorch
 from omegaconf import DictConfig
 from torch.distributed.device_mesh import init_device_mesh
@@ -115,49 +116,58 @@ def main(args: DictConfig) -> float | None:  # noqa: C901

     # Training loop
     step = start_step
+    micro_step = 0
     while step < args.num_train_steps:
         for batch in train_dataloader:
             batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa PLW2901

-            # Forward pass with mixed precision.
-            with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
-                outputs = model(**batch)
-
-            # Backward pass.
-            loss = outputs.loss
-            loss.backward()
-
-            # Compute and clip gradient norms.
-            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
-
-            # Step optimizer.
-            optimizer.step()
-            scheduler.step()
-            optimizer.zero_grad()
-
-            perf_logger.log_step(
-                step=step,
-                batch=batch,
-                outputs=outputs,
-                grad_norm=total_norm,
-                lr=optimizer.param_groups[0]["lr"],
-            )
-
-            if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
-                save_checkpoint_ddp(
-                    model=model,
-                    optimizer=optimizer,
-                    scheduler=scheduler,
-                    ckpt_path=ckpt_path,
+            micro_step += 1
+            with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext():
+
+                # Forward pass with mixed precision.
+                with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
+                    outputs = model(**batch)
+
+                # Backward pass.
+                loss = outputs.loss / args.grad_acc_steps
+                loss.backward()
+
+                # Log microbatch step data for accumulation metrics.
+                perf_logger.log_micro_step(batch, outputs)
+
+            # Gradient accumulation.
+            if micro_step % args.grad_acc_steps == 0:
+                micro_step = 0
+
+                # Compute and clip gradient norms.
+                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
+
+                # Step optimizer.
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+                perf_logger.log_step(
                     step=step,
-                    dist_config=dist_config,
-                    dataloader=train_dataloader,
-                    epoch=epoch,
+                    grad_norm=total_norm,
+                    lr=optimizer.param_groups[0]["lr"],
                 )

-            step += 1
-            if step >= args.num_train_steps:
-                break
+                if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
+                    save_checkpoint_ddp(
+                        model=model,
+                        optimizer=optimizer,
+                        scheduler=scheduler,
+                        ckpt_path=ckpt_path,
+                        step=step,
+                        dist_config=dist_config,
+                        dataloader=train_dataloader,
+                        epoch=epoch,
+                    )
+
+                step += 1
+                if step >= args.num_train_steps:
+                    break

         # Dataloader exhausted, incrementing epoch
         epoch += 1
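In train_ddp.py, gradient all-reduce is suppressed with model.no_sync() on every micro-batch except the one that triggers optimizer.step(), and the loss is scaled by 1 / grad_acc_steps so the accumulated gradient matches a single larger batch. A self-contained hedged sketch of that pattern with a toy model and random data (not the recipe's actual script); launch with `torchrun --nproc_per_node=2 sketch.py`:

# Hedged sketch: DDP gradient accumulation with no_sync(), CPU + gloo backend.
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")  # torchrun supplies rank/world-size env vars
torch.manual_seed(0)

model = DDP(torch.nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_acc_steps = 4

micro_step = 0
for x, y in [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]:
    micro_step += 1
    # Skip the gradient all-reduce except on the final micro-batch of the window.
    with model.no_sync() if micro_step % grad_acc_steps != 0 else nullcontext():
        loss = torch.nn.functional.mse_loss(model(x), y) / grad_acc_steps
        loss.backward()

    if micro_step % grad_acc_steps == 0:
        micro_step = 0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()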

bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py

Lines changed: 37 additions & 29 deletions
@@ -19,6 +19,7 @@
 import hydra
 import torch
 import transformer_engine.pytorch
+from contextlib import nullcontext
 from omegaconf import DictConfig, OmegaConf
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import fully_shard
@@ -119,49 +120,56 @@ def main(args: DictConfig) -> float | None:  # noqa: C901

     # Training loop
     step = start_step
+    micro_step = 0
     while step < args.num_train_steps:
         for batch in train_dataloader:
             batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa: PLW2901

+            micro_step += 1
+
             # Forward pass with mixed precision.
             with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
                 outputs = model(**batch)

             # Backward pass.
-            loss = outputs.loss
+            loss = outputs.loss / args.grad_acc_steps
             loss.backward()

-            # Compute and clip gradient norms.
-            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
-
-            # Step optimizer.
-            optimizer.step()
-            scheduler.step()
-            optimizer.zero_grad()
-
-            perf_logger.log_step(
-                step=step,
-                batch=batch,
-                outputs=outputs,
-                grad_norm=total_norm,
-                lr=optimizer.param_groups[0]["lr"],
-            )
-
-            if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
-                save_checkpoint_fsdp2(
-                    model=model,
-                    optimizer=optimizer,
-                    scheduler=scheduler,
-                    ckpt_path=ckpt_path,
+            # Log microbatch step data for accumulation metrics.
+            perf_logger.log_micro_step(batch, outputs)
+
+            if micro_step % args.grad_acc_steps == 0:
+                micro_step = 0
+
+                # Compute and clip gradient norms.
+                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
+
+                # Step optimizer.
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+                perf_logger.log_step(
                     step=step,
-                    epoch=epoch,
-                    dist_config=dist_config,
-                    dataloader=train_dataloader,
+                    grad_norm=total_norm,
+                    lr=optimizer.param_groups[0]["lr"],
                 )

-            step += 1
-            if step >= args.num_train_steps:
-                break
+                if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
+                    save_checkpoint_fsdp2(
+                        model=model,
+                        optimizer=optimizer,
+                        scheduler=scheduler,
+                        ckpt_path=ckpt_path,
+                        step=step,
+                        epoch=epoch,
+                        dist_config=dist_config,
+                        dataloader=train_dataloader,
+                    )
+
+                step += 1
+                if step >= args.num_train_steps:
+                    break

         # Dataloader exhausted, incrementing epoch
         epoch += 1
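In train_fsdp2.py there is no communication-skipping context: each micro-batch simply contributes outputs.loss / args.grad_acc_steps, and the optimizer steps once per accumulation window. A quick hedged sanity check (toy tensors, not recipe code) that this scaling reproduces the gradient of one combined batch when micro-batches are equal-sized and the loss is mean-reduced:

# Hedged sketch: accumulating gradients of loss / grad_acc_steps over several
# micro-batches matches a single backward over the combined batch.
import torch

torch.manual_seed(0)
w = torch.randn(8, 1, requires_grad=True)
data = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(4)]
grad_acc_steps = len(data)

# Accumulated path: four scaled backward passes.
for x, y in data:
    loss = torch.nn.functional.mse_loss(x @ w, y) / grad_acc_steps
    loss.backward()
accumulated_grad = w.grad.clone()

# Single large-batch path.
w.grad = None
x_all = torch.cat([x for x, _ in data])
y_all = torch.cat([y for _, y in data])
torch.nn.functional.mse_loss(x_all @ w, y_all).backward()

print(torch.allclose(accumulated_grad, w.grad, atol=1e-6))  # True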

bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py

Lines changed: 48 additions & 36 deletions
@@ -20,6 +20,7 @@
 import torch
 import transformer_engine.pytorch
 import transformers
+from contextlib import nullcontext
 from megatron_fsdp.fully_shard import fully_shard
 from omegaconf import DictConfig, OmegaConf
 from torch.distributed.device_mesh import init_device_mesh
@@ -134,49 +135,60 @@ def main(args: DictConfig) -> float | None:

     # Training loop
     step = start_step
+    micro_step = 0
     while step < args.num_train_steps:
         for batch in train_dataloader:
             batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa: PLW2901

-            # Forward pass with mixed precision.
-            with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
-                outputs = model(**batch)
-
-            # Backward pass.
-            loss = outputs.loss
-            loss.backward()
-
-            # Compute and clip gradient norms.
-            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
-
-            # Step optimizer.
-            optimizer.step()
-            scheduler.step()
-            optimizer.zero_grad()
-
-            perf_logger.log_step(
-                step=step,
-                batch=batch,
-                outputs=outputs,
-                grad_norm=total_norm,
-                lr=optimizer.param_groups[0]["lr"],
-            )
-
-            if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
-                save_checkpoint_mfsdp(
-                    model=model,
-                    optimizer=optimizer,
-                    scheduler=scheduler,
-                    ckpt_path=ckpt_path,
+            micro_step += 1
+            with model.sync() if micro_step % args.grad_acc_steps == 0 else nullcontext():
+
+                # Forward pass with mixed precision.
+                with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
+                    outputs = model(**batch)
+
+                # Backward pass.
+                loss = outputs.loss / args.grad_acc_steps
+                loss.backward()
+
+                # Log microbatch step data for accumulation metrics.
+                perf_logger.log_micro_step(batch, outputs)
+
+
+            # Gradient accumulation.
+            if micro_step % args.grad_acc_steps == 0:
+                micro_step = 0
+
+                # Compute and clip gradient norms.
+                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
+
+                # Step optimizer.
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+                perf_logger.log_step(
                     step=step,
-                    dist_config=dist_config,
-                    dataloader=train_dataloader,
-                    epoch=epoch,
+                    grad_norm=total_norm,
+                    lr=optimizer.param_groups[0]["lr"],
                 )

-            step += 1
-            if step >= args.num_train_steps:
-                break
+                if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
+                    save_checkpoint_mfsdp(
+                        model=model,
+                        optimizer=optimizer,
+                        scheduler=scheduler,
+                        ckpt_path=ckpt_path,
+                        step=step,
+                        dist_config=dist_config,
+                        dataloader=train_dataloader,
+                        epoch=epoch,
+                    )
+
+                step += 1
+                if step >= args.num_train_steps:
+                    break
+
         # Dataloader exhausted, incrementing epoch
         epoch += 1
         dataset_or_sampler.set_epoch(epoch)
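Note the inverted condition relative to the DDP script: Megatron-FSDP's model.sync() is entered only on the micro-batch that completes the accumulation window, whereas DDP's model.no_sync() is entered on every other micro-batch. A small hedged helper that makes the two conventions explicit (illustrative only, not part of the recipe; availability of sync() or no_sync() depends on the wrapper actually in use):

from contextlib import nullcontext


def grad_sync_context(model, micro_step: int, grad_acc_steps: int, uses_no_sync: bool):
    """Pick the gradient-communication context for this micro-batch.

    Illustrative helper, not recipe code. `uses_no_sync=True` matches torch DDP,
    where `model.no_sync()` suppresses communication on non-final micro-batches;
    `uses_no_sync=False` matches a wrapper like Megatron-FSDP, where `model.sync()`
    enables communication on the final micro-batch of the window.
    """
    is_final_micro_batch = micro_step % grad_acc_steps == 0
    if uses_no_sync:
        return nullcontext() if is_final_micro_batch else model.no_sync()
    return model.sync() if is_final_micro_batch else nullcontext()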
