
Conversation

@cspades cspades (Member) commented Oct 14, 2025

Description

  • Minor modifications to support a new grad_acc_steps config parameter that enables gradient accumulation in DDP, FSDP2, and Megatron-FSDP (MFSDP).

Usage

Given:

# World Size: 8
grad_acc_steps: 2
micro_batch_size: 4

we have:

# Effective Global Batch Size (WorldSize x MBS x GradAcc): 64
# Per-Step Batch Size (WorldSize x MBS): 32
# Per-Rank Effective Batch Size (MBS x GradAcc): 8
# Per-Rank Per-Step Batch Size (MBS): 4
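For context, the accumulation pattern behind these numbers looks roughly like the following DDP sketch. This is illustrative only: model, optimizer, scheduler, and dataloader are assumed to be constructed elsewhere, and it is not the recipe's exact code.

from contextlib import nullcontext

import torch

grad_acc_steps = 2  # matches the config above
micro_step = 0

for batch in dataloader:  # each batch is one micro-batch of micro_batch_size samples per rank
    micro_step += 1
    is_boundary = micro_step == grad_acc_steps
    # Skip DDP's gradient all-reduce on non-boundary micro-batches.
    with (nullcontext() if is_boundary else model.no_sync()):
        loss = model(**batch).loss / grad_acc_steps  # scale so the accumulated gradient averages over micro-batches
        loss.backward()
    if not is_boundary:
        continue
    # Optimizer, scheduler, logging, and checkpointing only happen at the accumulation boundary.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    micro_step = 0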

Known Issues

  • Megatron-FSDP gradients don't match FSDP2 / DDP gradients. Will investigate separately.
  • The gradient norm reported for logging may need to be normalized, since FSDP2 and Megatron-FSDP shard gradients across ranks, so each rank's local norm only covers its own shard (see the sketch below).
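One hedged way to recover a global norm for logging is to all-reduce the squared norms of each rank's local gradient shards. This is a sketch under the assumption that every rank holds a disjoint shard and that .grad exposes the local shard; FSDP2's DTensor gradients may need converting to local tensors first.

import torch
import torch.distributed as dist

def global_grad_norm(parameters) -> float:
    """Combine per-rank gradient-shard norms into one global L2 norm for logging."""
    local_sq = torch.zeros(1, device="cuda")
    for p in parameters:
        if p.grad is not None:
            local_sq += p.grad.detach().float().norm(2) ** 2  # squared norm of this rank's shard
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)  # sum squared shard norms across all ranks
    return local_sq.sqrt().item()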

Loss Curves

  • The baseline is DDP without gradient accumulation (optimizer and LR scheduler stepped on every micro-batch). DDP and FSDP2 with 64 gradient accumulation steps (micro-batch size reduced to keep the same effective batch size) produce the same loss curve as the baseline when metrics are reported once per optimization step.

https://api.wandb.ai/links/nvidia/99tuw05t
(Screenshot: loss-curve comparison, captured October 14, 2025)

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

  • ciflow:skip - Skip all CI tests for this PR
  • ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
  • ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
  • ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running all bionemo2 tests.
  • ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.

For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

Summary by CodeRabbit

  • New Features

    • Gradient accumulation with micro-batching across training modes, controlled by grad_acc_steps.
  • Improvements

    • Per-micro-step metric aggregation for loss, perplexity, and token throughput; progress bar and logs show accumulated values.
    • Checkpointing, gradient clipping, optimizer and scheduler updates occur at accumulation boundaries.
    • Input validation ensures grad_acc_steps >= 1.
  • Configuration

    • New grad_acc_steps option (default: 1).
  • Chores

    • Lint rule C901 ignored for this recipe.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 14, 2025

Walkthrough

Introduces gradient accumulation across the ESM2 Native TE recipes: adds a grad_acc_steps config, updates the training loops (DDP, FSDP2, M-FSDP) to iterate over micro-batches, scale the loss, and defer optimizer and checkpoint work to accumulation boundaries, and revises PerfLogger to accumulate micro-step metrics via a new log_micro_step and a simplified log_step. Also adds a Ruff ignore for C901.

Changes

  • Config: grad accumulation parameter (bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml): Adds grad_acc_steps: 1 to enable configuring gradient accumulation.
  • Performance logging: accumulation-aware (bionemo-recipes/recipes/esm2_native_te/perf_logger.py): Adds accumulation fields (num_tokens, num_unpadded_tokens, running_loss, grad_acc_step_count), a new log_micro_step(batch, outputs) that collects per-micro-step metrics, and changes log_step(step, grad_norm, lr) to consume the aggregated metrics (computes averaged loss and tokens/sec, resets accumulators, updates the progress bar and W&B). Preserves perplexity handling and normalizes logits where needed. See the sketch after this list.
  • DDP training loop: gradient accumulation (bionemo-recipes/recipes/esm2_native_te/train_ddp.py): Implements micro-batching with grad_acc_steps: validates the arg, tracks micro_step, uses model.no_sync() for non-boundary micro-steps, scales the loss per micro-batch, calls perf_logger.log_micro_step, and performs gradient clipping / optimizer.step() / scheduler.step() / zero_grad() and checkpointing only at accumulation boundaries; advances step on boundaries.
  • FSDP2 training loop: gradient accumulation (bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py): Validates grad_acc_steps >= 1, adds micro_step tracking, scales the loss by 1/grad_acc_steps, logs micro-steps via perf_logger.log_micro_step, defers gradient updates and checkpoint saves to accumulation boundaries, and calls perf_logger.log_step at boundaries.
  • M-FSDP training loop: gradient accumulation (bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py): Adds micro_step tracking and an accumulation cadence, uses model.sync() / nullcontext() around micro-steps, scales the loss, logs micro-steps, executes gradient clipping / optimizer / scheduler / zero_grad and checkpointing on accumulation boundaries, and increments step only on a boundary.
  • Lint config (bionemo-recipes/recipes/esm2_native_te/.ruff.toml): Adds an ignore for Ruff rule C901.
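To make the PerfLogger change concrete, here is a rough sketch of the accumulation-aware interface described above. Field and method names follow the summary, but the bodies are illustrative and not the actual implementation; for example, the unpadded-token count here uses the labels mask, one of the options discussed in the review comments below.

import time


class PerfLoggerSketch:
    """Illustrative shape of the accumulation-aware logger; not the recipe's implementation."""

    def __init__(self, grad_acc_steps: int):
        self.grad_acc_steps = grad_acc_steps
        self.num_tokens = 0
        self.num_unpadded_tokens = 0
        self.running_loss = 0.0
        self.previous_step_time = time.perf_counter()

    def log_micro_step(self, batch, outputs):
        # Accumulate per-micro-batch statistics; nothing is emitted yet.
        self.num_tokens += batch["input_ids"].numel()
        self.num_unpadded_tokens += int((batch["labels"] != -100).sum())
        self.running_loss += outputs.loss.item()

    def log_step(self, step: int, grad_norm: float, lr: float):
        # Emit averaged metrics at the accumulation boundary, then reset the accumulators.
        now = time.perf_counter()
        step_time, self.previous_step_time = now - self.previous_step_time, now
        avg_loss = self.running_loss / self.grad_acc_steps
        print(
            f"step={step} loss={avg_loss:.4f} tokens/s={self.num_tokens / step_time:.0f} "
            f"grad_norm={grad_norm:.3f} lr={lr:.2e}"
        )
        self.num_tokens = self.num_unpadded_tokens = 0
        self.running_loss = 0.0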

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Trainer
    participant Model as Model (DDP/FSDP/M-FSDP)
    participant Optim as Optimizer
    participant Sched as Scheduler
    participant Logger as PerfLogger
    participant Ckpt as Checkpointer

    rect rgba(235,245,255,0.6)
    note over Trainer: Training loop with gradient accumulation
    loop For each accumulation cycle
        Trainer->>Trainer: micro_step = 1..grad_acc_steps
        alt DDP non-boundary micro
            Trainer->>Model: forward(batch) with no_sync + FP8 autocast
        else Boundary or non-DDP/FSDP
            Trainer->>Model: forward(batch) with FP8 autocast
        end
        Model-->>Trainer: outputs (loss, logits)
        Trainer->>Trainer: scaled_loss = outputs.loss / grad_acc_steps
        Trainer->>Model: backward(scaled_loss)
        Trainer->>Logger: log_micro_step(batch, outputs)
        opt micro_step < grad_acc_steps
            Trainer-->>Trainer: continue accumulating
        end
    end
    end

    rect rgba(240,255,240,0.6)
    note over Trainer,Optim: Accumulation boundary
    Trainer->>Optim: clip_grad_norm_
    Trainer->>Optim: step()
    Trainer->>Sched: step()
    Trainer->>Optim: zero_grad()
    Trainer->>Logger: log_step(step, grad_norm, lr)
    Trainer->>Ckpt: save_if_needed(step, epoch, dist_config, dataloader)
    Trainer->>Trainer: step += 1
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble bytes between the steps, so small, so neat,
A dozen hops before I land with steady feet.
I stash my seeds—loss, tokens—tucked away,
Then bound at once when gradients say “okay.”
Carrots cached, I checkpoint, blink—then leap! 🥕🐇

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Title Check: ✅ Passed. The title concisely and accurately summarizes the primary change (adding gradient accumulation support to ESM-2) without extraneous detail.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, above the required threshold of 80.00%.
  • Description Check: ✅ Passed. The description follows the repository's template: it provides a Description section, a Usage subsection with code snippets, a Type of changes section with the correct option selected, and the CI Pipeline Configuration, Authorizing CI Runs, and Pre-submit Checklist headings. Optional Known Issues and Loss Curves sections are included without disrupting the required structure.


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)

43-43: Fix CI: C901 'main' too complex

Silence the lint error to unblock CI, or refactor. Minimal patch:

-def main(args: DictConfig) -> float | None:
+def main(args: DictConfig) -> float | None:  # noqa: C901
🧹 Nitpick comments (7)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)

76-86: Avoid hardcoded PAD id; derive unpadded tokens from labels/attention_mask

Use labels (ignore_index=-100) or attention_mask to compute unpadded tokens; hardcoding 1 risks miscounts across tokenizers.

 def log_micro_step(self, batch, outputs):
@@
-        self.num_unpadded_tokens += batch["input_ids"][batch["input_ids"] != 1].numel()
+        if "attention_mask" in batch:
+            # Fast path when available
+            self.num_unpadded_tokens += batch["attention_mask"].sum().item()
+        else:
+            # Fall back to labels mask aligned with ignore_index used by Perplexity
+            self.num_unpadded_tokens += (batch["labels"] != -100).sum().item()

101-110: Guard against zero micro-steps and clamp step_time

Prevent division by zero if log_step is called without prior micro-steps and handle near-zero timers to avoid inf TPS.

-        self.min_loss = min(self.min_loss, self.running_loss / self.grad_acc_steps)
-        step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter()
+        if self.grad_acc_steps == 0:
+            logger.warning("log_step called with no accumulated micro-steps; skipping log.")
+            return
+        self.min_loss = min(self.min_loss, self.running_loss / self.grad_acc_steps)
+        step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter()
+        if step_time <= 0:
+            step_time = 1e-6
@@
-        self.metrics["train/tokens_per_second"].update(self.num_tokens / step_time)
-        self.metrics["train/unpadded_tokens_per_second"].update(self.num_unpadded_tokens / step_time)
+        self.metrics["train/tokens_per_second"].update(self.num_tokens / step_time)
+        self.metrics["train/unpadded_tokens_per_second"].update(self.num_unpadded_tokens / step_time)
bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1)

4-4: Add runtime validation for grad_acc_steps and confirm mFSDP sync setting

Default 1 is fine. Validate at runtime that grad_acc_steps >= 1, and please confirm whether fully_shard_kwargs.sync_model_each_microbatch should be disabled when grad_acc_steps > 1 to avoid redundant grad syncs in mFSDP.

bionemo-recipes/recipes/esm2_native_te/train_ddp.py (1)

117-120: Validate grad_acc_steps at startup

Add a quick check to fail fast on invalid configs.

# Near the start of main(), after args are available
if args.grad_acc_steps < 1:
    raise ValueError(f"grad_acc_steps must be >= 1, got {args.grad_acc_steps}")
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (1)

121-124: Validate grad_acc_steps at startup

Add a guard to enforce grad_acc_steps >= 1.

# Near the start of main(), after args are available
if args.grad_acc_steps < 1:
    raise ValueError(f"grad_acc_steps must be >= 1, got {args.grad_acc_steps}")
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2)

143-151: Verify interaction with sync_model_each_microbatch

Using model.sync() only on the boundary is the right pattern. Please confirm fully_shard_kwargs.sync_model_each_microbatch is false when grad_acc_steps > 1; otherwise you may still sync every micro-batch.

Optionally add a guard:

if args.grad_acc_steps > 1 and getattr(args.fully_shard_kwargs, "sync_model_each_microbatch", False):
    logger.warning("grad_acc_steps > 1 with sync_model_each_microbatch=True may hurt performance.")

136-139: Validate grad_acc_steps at startup

Add a guard to enforce grad_acc_steps >= 1.

# Near the start of main(), after args are available
if args.grad_acc_steps < 1:
    raise ValueError(f"grad_acc_steps must be >= 1, got {args.grad_acc_steps}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 35d2422 and 2f6b3d2.

📒 Files selected for processing (5)
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1 hunks)
  • bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2 hunks)
  • bionemo-recipes/recipes/esm2_native_te/train_ddp.py (2 hunks)
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (2 hunks)
  • bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  • log_micro_step (76-85)
  • log_step (87-128)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  • should_save_checkpoint (78-82)
  • save_checkpoint_fsdp2 (410-454)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  • log_micro_step (76-85)
  • log_step (87-128)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  • should_save_checkpoint (78-82)
  • save_checkpoint_ddp (125-158)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  • log_micro_step (76-85)
  • log_step (87-128)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  • should_save_checkpoint (78-82)
  • save_checkpoint_mfsdp (237-280)
🪛 GitHub Actions: BioNeMo Framework CI
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py

[warning] Ruff formatting changes detected in related files (files reformatted by ruff).

bionemo-recipes/recipes/esm2_native_te/train_ddp.py

[warning] Ruff formatting changes detected in related files (files reformatted by ruff).

bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py

[error] 43-43: C901: 'main' is too complex (11 > 10).

bionemo-recipes/recipes/esm2_native_te/perf_logger.py

[error] 1-1: Trailing whitespace detected and removed by pre-commit hooks.


[warning] 97-98: Ruff: code block formatting changes were applied.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)
🔇 Additional comments (3)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (2)

124-137: DDP accumulation pattern looks correct

no_sync on non-boundary micro-steps + loss scaling + micro logging is solid.


150-154: Pass lr/grad_norm at accumulation boundary only (good)

Logging only on update steps keeps metrics aligned to effective steps.

Ensure grad clipping threshold (1.0) is intended; adjust if different model sizes need another norm.

bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (1)

141-172: Confirm FSDP2 gradient sync behavior under accumulation

No no_sync context is used here. With loss scaling, correctness holds, but you may incur extra allreduces per micro-batch. Please confirm FSDP2’s fully_shard() wrapper does not need an explicit no_sync (or similar) for accumulation efficiency; if available, wrap forward/backward with it for non-boundary micro-steps.
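If an explicit deferral turns out to be needed, one candidate is FSDP2's per-module gradient-sync toggle. The following is a sketch only: it assumes the wrapped model exposes set_requires_gradient_sync from PyTorch's FSDP2 fully_shard API, and that grad_acc_steps, batch_iter, model, and optimizer are defined elsewhere; verify the API against the installed PyTorch version before relying on it.

# Sketch: defer gradient reduce-scatter to the boundary micro-step with FSDP2 (verify API availability first).
for micro_step in range(1, grad_acc_steps + 1):
    batch = next(batch_iter)
    is_boundary = micro_step == grad_acc_steps
    model.set_requires_gradient_sync(is_boundary)  # skip gradient reduction on non-boundary micro-steps
    loss = model(**batch).loss / grad_acc_steps
    loss.backward()
# Gradients are reduced during the boundary backward; clip, step, and zero_grad as usual afterwards.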

@cspades cspades force-pushed the cye/bnmo-recipe-gradacc branch from 2f6b3d2 to 2afdefe on October 14, 2025 at 18:44
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (5)
bionemo-recipes/recipes/esm2_native_te/.ruff.toml (1)

2-2: Scope the C901 suppression more narrowly.

Let’s avoid blanketing the entire recipe with a global C901 ignore; it makes it easy for future complex functions to slip in unnoticed. Please narrow the suppression (e.g., inline # noqa: C901, per-file ignore, or bumping mccabe.max-complexity only where justified).

bionemo-recipes/recipes/esm2_native_te/train_ddp.py (1)

145-147: Minor: zero_grad(set_to_none=True) for perf/memory.

Optional optimization to reduce memory writes and improve perf.

-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=True)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)

164-167: Minor: zero_grad(set_to_none=True).

Small perf tweak.

-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=True)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)

75-85: Avoid hard-coded PAD id and in-place logits mutation.

  • Counting unpadded tokens via input_ids != 1 assumes pad_token_id=1. Prefer attention_mask if available.
  • Don’t mutate outputs.logits; use a local var to adjust dims before update.
-        self.num_unpadded_tokens += batch["input_ids"][batch["input_ids"] != 1].numel()
-        self.running_loss += outputs.loss.item()
-        # Handle sequence packing for torchmetrics calculation.
-        if outputs.logits.dim() < 3:
-            outputs.logits = outputs.logits.unsqueeze(0)
-        self.metrics["train/perplexity"].update(outputs.logits, batch["labels"])
+        if "attention_mask" in batch:
+            # Preferred: rely on attention_mask if provided
+            self.num_unpadded_tokens += batch["attention_mask"].sum().item()
+        else:
+            # Fallback: count non-pad tokens; try to read pad_token_id from batch/config if present, else assume 1
+            pad_id = batch.get("pad_token_id", 1) if isinstance(batch, dict) else 1
+            self.num_unpadded_tokens += (batch["input_ids"] != pad_id).sum().item()
+        self.running_loss += outputs.loss.item()
+        # Handle sequence packing for torchmetrics calculation without mutating outputs
+        logits = outputs.logits if outputs.logits.dim() >= 3 else outputs.logits.unsqueeze(0)
+        self.metrics["train/perplexity"].update(logits, batch["labels"])

117-117: Progress bar postfix uses pre-reset loss: OK but consider clarity.

Optional: compute once into a local variable to avoid recomputation and ensure clarity.

-                self._progress_bar.set_postfix({"loss": self.running_loss / self.grad_acc_steps})
+                avg_loss = self.running_loss / self.grad_acc_steps
+                self._progress_bar.set_postfix({"loss": avg_loss})
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2f6b3d2 and 2afdefe.

📒 Files selected for processing (6)
  • bionemo-recipes/recipes/esm2_native_te/.ruff.toml (1 hunks)
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1 hunks)
  • bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2 hunks)
  • bionemo-recipes/recipes/esm2_native_te/train_ddp.py (3 hunks)
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (2 hunks)
  • bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
🧰 Additional context used
🧬 Code graph analysis (2)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (3)
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (1)
  • main (44-189)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  • log_micro_step (75-84)
  • log_step (86-126)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  • should_save_checkpoint (78-82)
  • save_checkpoint_ddp (125-158)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (2)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)
  • log_micro_step (75-84)
  • log_step (86-126)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (2)
  • should_save_checkpoint (78-82)
  • save_checkpoint_mfsdp (237-280)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)
🔇 Additional comments (4)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py (1)

119-136: Good GA integration: correct no_sync cadence and loss scaling.

  • micro_step gating with DDP no_sync is correct.
  • FP8 autocast scope + per-micro backward with loss/grad_acc_steps looks right.
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (2)

70-74: Initialize accumulation counters: LGTM.

Counters are correctly initialized for accumulation.


122-127: Counter resets: LGTM.

Resets are correctly scoped to accumulation window end.

bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)

143-155: GA flow looks correct; verify model.sync() semantics

  • Micro-step accumulation and loss scaling look good.
  • Confirm that your mFSDP-wrapped model provides a sync() context manager that triggers gradient synchronization only on accumulation boundaries; otherwise use the appropriate API (e.g., no_sync()) to skip sync on other micro-steps (see the sketch below).
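For reference, the boundary-only pattern the walkthrough describes might look like the following. This is a sketch: it assumes the Megatron-FSDP wrapper exposes a sync() context manager that enables gradient synchronization for the enclosed backward, and that micro_step, grad_acc_steps, batch, and model are defined in the surrounding loop.

from contextlib import nullcontext

# Synchronize gradients only on the boundary micro-step (assumes model.sync() behaves as described above).
sync_ctx = model.sync() if micro_step == grad_acc_steps else nullcontext()
with sync_ctx:
    loss = model(**batch).loss / grad_acc_steps
    loss.backward()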

@cspades cspades force-pushed the cye/bnmo-recipe-gradacc branch from 2afdefe to 9d269d4 on October 14, 2025 at 19:34
