Add gradient accumulation to ESM-2. #1254
base: main
Conversation
Walkthrough: Introduces gradient accumulation across the ESM2 Native TE recipes, adding a `grad_acc_steps` configuration parameter and updating the DDP, FSDP2, and mFSDP training loops to accumulate gradients over micro-steps before each optimizer update.
Changes
Sequence Diagram(s)
```mermaid
sequenceDiagram
autonumber
participant Trainer
participant Model as Model (DDP/FSDP/M-FSDP)
participant Optim as Optimizer
participant Sched as Scheduler
participant Logger as PerfLogger
participant Ckpt as Checkpointer
rect rgba(235,245,255,0.6)
note over Trainer: Training loop with gradient accumulation
loop For each accumulation cycle
Trainer->>Trainer: micro_step = 1..grad_acc_steps
alt DDP non-boundary micro
Trainer->>Model: forward(batch) with no_sync + FP8 autocast
else Boundary or non-DDP/FSDP
Trainer->>Model: forward(batch) with FP8 autocast
end
Model-->>Trainer: outputs (loss, logits)
Trainer->>Trainer: scaled_loss = outputs.loss / grad_acc_steps
Trainer->>Model: backward(scaled_loss)
Trainer->>Logger: log_micro_step(batch, outputs)
opt micro_step < grad_acc_steps
Trainer-->>Trainer: continue accumulating
end
end
end
rect rgba(240,255,240,0.6)
note over Trainer,Optim: Accumulation boundary
Trainer->>Optim: clip_grad_norm_
Trainer->>Optim: step()
Trainer->>Sched: step()
Trainer->>Optim: zero_grad()
Trainer->>Logger: log_step(step, grad_norm, lr)
Trainer->>Ckpt: save_if_needed(step, epoch, dist_config, dataloader)
Trainer->>Trainer: step += 1
end
```
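For readers skimming the diagram, the accumulation pattern it describes maps roughly to the training-loop sketch below. This is an illustrative outline only, not the recipe's actual code: it assumes an HF-style `outputs.loss`, reuses the `grad_acc_steps`, `log_micro_step`, and `log_step` names from this PR, and omits the FP8 autocast and checkpointing shown in the diagram.

```python
import contextlib

import torch


def train_with_grad_accumulation(model, optimizer, scheduler, dataloader, grad_acc_steps, perf_logger):
    """Minimal gradient-accumulation loop (sketch; see the recipe scripts for the real implementation)."""
    step = 0
    optimizer.zero_grad()
    for micro_step, batch in enumerate(dataloader, start=1):
        is_boundary = micro_step % grad_acc_steps == 0
        # For DDP, skip the gradient all-reduce on non-boundary micro-steps.
        sync_ctx = model.no_sync() if hasattr(model, "no_sync") and not is_boundary else contextlib.nullcontext()
        with sync_ctx:
            outputs = model(**batch)
            # Scale the loss so the accumulated gradient matches one large-batch step.
            (outputs.loss / grad_acc_steps).backward()
        perf_logger.log_micro_step(batch, outputs)

        if not is_boundary:
            continue  # keep accumulating

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        perf_logger.log_step(step, grad_norm, scheduler.get_last_lr()[0])
        step += 1
```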
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)
43-43: Fix CI: C901 'main' is too complex. Silence the lint error to unblock CI, or refactor. Minimal patch:
```diff
-def main(args: DictConfig) -> float | None:
+def main(args: DictConfig) -> float | None:  # noqa: C901
```
🧹 Nitpick comments (7)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
76-86: Avoid a hardcoded PAD id; derive unpadded tokens from labels/attention_mask. Use labels (ignore_index=-100) or attention_mask to compute unpadded tokens; hardcoding pad id 1 risks miscounts across tokenizers.
```diff
     def log_micro_step(self, batch, outputs):
 @@
-        self.num_unpadded_tokens += batch["input_ids"][batch["input_ids"] != 1].numel()
+        if "attention_mask" in batch:
+            # Fast path when available
+            self.num_unpadded_tokens += batch["attention_mask"].sum().item()
+        else:
+            # Fall back to labels mask aligned with ignore_index used by Perplexity
+            self.num_unpadded_tokens += (batch["labels"] != -100).sum().item()
```
101-110: Guard against zero micro-steps and clamp step_time. Prevent division by zero if log_step is called without prior micro-steps, and handle near-zero timers to avoid infinite tokens-per-second values.
```diff
-        self.min_loss = min(self.min_loss, self.running_loss / self.grad_acc_steps)
-        step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter()
+        if self.grad_acc_steps == 0:
+            logger.warning("log_step called with no accumulated micro-steps; skipping log.")
+            return
+        self.min_loss = min(self.min_loss, self.running_loss / self.grad_acc_steps)
+        step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter()
+        if step_time <= 0:
+            step_time = 1e-6
 @@
-        self.metrics["train/tokens_per_second"].update(self.num_tokens / step_time)
-        self.metrics["train/unpadded_tokens_per_second"].update(self.num_unpadded_tokens / step_time)
+        self.metrics["train/tokens_per_second"].update(self.num_tokens / step_time)
+        self.metrics["train/unpadded_tokens_per_second"].update(self.num_unpadded_tokens / step_time)
```
bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1)
4-4: Add runtime validation for grad_acc_steps and confirm the mFSDP sync setting. A default of 1 is fine. Validate at runtime that grad_acc_steps >= 1, and please confirm whether fully_shard_kwargs.sync_model_each_microbatch should be disabled when grad_acc_steps > 1 to avoid redundant grad syncs in mFSDP.
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (1)
117-120: Validate grad_acc_steps at startup. Add a quick check to fail fast on invalid configs.
```python
# Near the start of main(), after args are available
if args.grad_acc_steps < 1:
    raise ValueError(f"grad_acc_steps must be >= 1, got {args.grad_acc_steps}")
```
121-124: Validate grad_acc_steps at startup. Add a guard to enforce grad_acc_steps >= 1.
```python
# Near the start of main(), after args are available
if args.grad_acc_steps < 1:
    raise ValueError(f"grad_acc_steps must be >= 1, got {args.grad_acc_steps}")
```
143-151: Verify the interaction with sync_model_each_microbatch. Using model.sync() only on the boundary is the right pattern. Please confirm fully_shard_kwargs.sync_model_each_microbatch is false when grad_acc_steps > 1; otherwise you may still sync every micro-batch.
Optionally add a guard:
```python
if args.grad_acc_steps > 1 and getattr(args.fully_shard_kwargs, "sync_model_each_microbatch", False):
    logger.warning("grad_acc_steps > 1 with sync_model_each_microbatch=True may hurt performance.")
```
136-139: Validate grad_acc_steps at startup. Add a guard to enforce grad_acc_steps >= 1.
```python
# Near the start of main(), after args are available
if args.grad_acc_steps < 1:
    raise ValueError(f"grad_acc_steps must be >= 1, got {args.grad_acc_steps}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1 hunks)
- bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2 hunks)
- bionemo-recipes/recipes/esm2_native_te/train_ddp.py (2 hunks)
- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (2 hunks)
- bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  - log_micro_step (76-85)
  - log_step (87-128)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  - should_save_checkpoint (78-82)
  - save_checkpoint_fsdp2 (410-454)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  - log_micro_step (76-85)
  - log_step (87-128)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  - should_save_checkpoint (78-82)
  - save_checkpoint_ddp (125-158)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  - log_micro_step (76-85)
  - log_step (87-128)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  - should_save_checkpoint (78-82)
  - save_checkpoint_mfsdp (237-280)
🪛 GitHub Actions: BioNeMo Framework CI
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
[warning] Ruff formatting changes detected in related files (files reformatted by ruff).
bionemo-recipes/recipes/esm2_native_te/train_ddp.py
[warning] Ruff formatting changes detected in related files (files reformatted by ruff).
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py
[error] 43-43: C901: 'main' is too complex (11 > 10).
bionemo-recipes/recipes/esm2_native_te/perf_logger.py
[error] 1-1: Trailing whitespace detected and removed by pre-commit hooks.
[warning] 97-98: Ruff: code block formatting changes were applied.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Analyze (rust)
🔇 Additional comments (3)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (2)
124-137: The DDP accumulation pattern looks correct. no_sync on non-boundary micro-steps + loss scaling + micro-step logging is solid.
150-154: Passing lr/grad_norm at the accumulation boundary only (good). Logging only on update steps keeps metrics aligned to effective steps.
Ensure grad clipping threshold (1.0) is intended; adjust if different model sizes need another norm.
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (1)
141-172: Confirm FSDP2 gradient sync behavior under accumulation. No no_sync context is used here. With loss scaling, correctness holds, but you may incur extra all-reduces per micro-batch. Please confirm FSDP2's fully_shard() wrapper does not need an explicit no_sync (or similar) for accumulation efficiency; if available, wrap forward/backward with it for non-boundary micro-steps.
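If the extra per-micro-batch reductions do turn out to matter, one possible direction, assuming a recent PyTorch where `fully_shard` returns an `FSDPModule` exposing `set_requires_gradient_sync`, is sketched below. This is a hypothetical illustration, not a verified change to the recipe; the loop variable names mirror the training scripts.

```python
import torch


def fsdp2_accumulation_loop(model, optimizer, scheduler, dataloader, grad_acc_steps):
    """Sketch: defer FSDP2 gradient reduction to accumulation boundaries (illustrative only)."""
    for micro_step, batch in enumerate(dataloader, start=1):
        is_boundary = micro_step % grad_acc_steps == 0
        # Only reduce-scatter gradients on the boundary micro-step.
        model.set_requires_gradient_sync(is_boundary)
        outputs = model(**batch)
        (outputs.loss / grad_acc_steps).backward()
        if is_boundary:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
```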
Force-pushed from 2f6b3d2 to 2afdefe
Actionable comments posted: 4
🧹 Nitpick comments (5)
bionemo-recipes/recipes/esm2_native_te/.ruff.toml (1)
2-2: Scope the C901 suppression more narrowly. Avoid blanketing the entire recipe with a global C901 ignore; it makes it easy for future complex functions to slip in unnoticed. Please narrow the suppression (e.g., an inline `# noqa: C901`, a per-file ignore, or bumping `mccabe.max-complexity` only where justified).
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (1)
145-147: Minor: zero_grad(set_to_none=True) for perf/memory. Optional optimization to reduce memory writes and improve performance.
```diff
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
```
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)
164-167: Minor: zero_grad(set_to_none=True). Small perf tweak.
```diff
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
```
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
75-85: Avoid hard-coded PAD id and in-place logits mutation.
- Counting unpadded tokens via input_ids != 1 assumes pad_token_id=1. Prefer attention_mask if available.
- Don’t mutate outputs.logits; use a local var to adjust dims before update.
```diff
-        self.num_unpadded_tokens += batch["input_ids"][batch["input_ids"] != 1].numel()
-        self.running_loss += outputs.loss.item()
-        # Handle sequence packing for torchmetrics calculation.
-        if outputs.logits.dim() < 3:
-            outputs.logits = outputs.logits.unsqueeze(0)
-        self.metrics["train/perplexity"].update(outputs.logits, batch["labels"])
+        if "attention_mask" in batch:
+            # Preferred: rely on attention_mask if provided
+            self.num_unpadded_tokens += batch["attention_mask"].sum().item()
+        else:
+            # Fallback: count non-pad tokens; try to read pad_token_id from batch/config if present, else assume 1
+            pad_id = batch.get("pad_token_id", 1) if isinstance(batch, dict) else 1
+            self.num_unpadded_tokens += (batch["input_ids"] != pad_id).sum().item()
+        self.running_loss += outputs.loss.item()
+        # Handle sequence packing for torchmetrics calculation without mutating outputs
+        logits = outputs.logits if outputs.logits.dim() >= 3 else outputs.logits.unsqueeze(0)
+        self.metrics["train/perplexity"].update(logits, batch["labels"])
```
117-117: The progress bar postfix uses the pre-reset loss: OK, but consider clarity. Optional: compute the value once into a local variable to avoid recomputation and make the intent clearer.
```diff
-        self._progress_bar.set_postfix({"loss": self.running_loss / self.grad_acc_steps})
+        avg_loss = self.running_loss / self.grad_acc_steps
+        self._progress_bar.set_postfix({"loss": avg_loss})
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- bionemo-recipes/recipes/esm2_native_te/.ruff.toml (1 hunks)
- bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1 hunks)
- bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2 hunks)
- bionemo-recipes/recipes/esm2_native_te/train_ddp.py (3 hunks)
- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (2 hunks)
- bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
🧰 Additional context used
🧬 Code graph analysis (2)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (3)
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (1)
  - main (44-189)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  - log_micro_step (75-84)
  - log_step (86-126)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  - should_save_checkpoint (78-82)
  - save_checkpoint_ddp (125-158)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  - log_micro_step (75-84)
  - log_step (86-126)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  - should_save_checkpoint (78-82)
  - save_checkpoint_mfsdp (237-280)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Analyze (rust)
🔇 Additional comments (4)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (1)
119-136: Good GA integration: correct no_sync cadence and loss scaling.
- micro_step gating with DDP no_sync is correct.
- FP8 autocast scope + per-micro backward with loss/grad_acc_steps looks right.
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
70-74: Initialize accumulation counters: LGTM. Counters are correctly initialized for accumulation.
122-127: Counter resets: LGTM. Resets are correctly scoped to the end of the accumulation window.
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)
143-155: GA flow looks correct; verify model.sync() semantics.
- Micro-step accumulation and loss scaling look good.
- Confirm that your mFSDP-wrapped model provides a sync() context manager that triggers gradient synchronization only on accumulation boundaries; otherwise use the appropriate API (e.g., no_sync()) to skip sync on other micro-steps.
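If the wrapper's exact API is uncertain, one way to express the boundary-only behavior generically is to choose the context manager at runtime. The sync() / no_sync() names below are the ones discussed above and should be checked against the actual mFSDP wrapper; this is a sketch, not the recipe's code.

```python
import contextlib


def grad_sync_context(model, is_boundary: bool):
    """Pick a gradient-sync context for the current micro-step (illustrative only)."""
    if is_boundary and hasattr(model, "sync"):
        return model.sync()  # synchronize gradients at the accumulation boundary
    if not is_boundary and hasattr(model, "no_sync"):
        return model.no_sync()  # suppress per-micro-batch synchronization
    return contextlib.nullcontext()


# Usage inside the micro-step loop:
# with grad_sync_context(model, micro_step % grad_acc_steps == 0):
#     outputs = model(**batch)
#     (outputs.loss / grad_acc_steps).backward()
```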
Signed-off-by: Cory Ye <[email protected]>
Force-pushed from 2afdefe to 9d269d4
Description
- Adds a `grad_acc_steps` config parameter for activating gradient accumulation in DDP, FSDP2, and MFSDP.
Usage
Given:
we have:
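The original values for this section are not reproduced above, so as a stand-in, here is how `grad_acc_steps` composes with the per-device micro-batch size and world size (all numbers below are hypothetical, not taken from the PR):

```python
# Hypothetical values, for illustration only.
micro_batch_size = 16  # sequences per GPU per forward/backward pass
grad_acc_steps = 4     # micro-steps accumulated before each optimizer.step()
world_size = 8         # data-parallel ranks

# Sequences contributing to one optimizer update (effective global batch size):
effective_batch_size = micro_batch_size * grad_acc_steps * world_size
print(effective_batch_size)  # 512
```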
Known Issues
Loss Curves
https://api.wandb.ai/links/nvidia/99tuw05t

Type of changes
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.
Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline. For more details, see CONTRIBUTING.
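For context, these markers are plain pytest marks; a hypothetical test using one might look like the sketch below (the test name is made up for illustration).

```python
import pytest


@pytest.mark.multi_gpu
def test_grad_accumulation_matches_large_batch_hypothetical():
    """Hypothetical multi-GPU test; skipped in the default PR pipeline per the note above."""
    ...
```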
Note
By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.
- Eligible pull requests will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123).
- Otherwise, an /ok to test comment on the pull request is required to trigger CI. This will need to be done for each new commit.
Pre-submit Checklist
Summary by CodeRabbit
New Features
Improvements
Configuration
Chores