
Conversation


@adil-a adil-a commented Nov 5, 2025

What does this PR do ?

Adds FP16 support for policy training.

https://wandb.ai/nvidia/automodel-rl/workspace?nw=6pzs4djqn28

The wandb workspace above shows BF16 (v1 policy) and FP16 (v1 and v2 policies) runs.

Signed-off-by: Hemil Desai <[email protected]>
@adil-a adil-a marked this pull request as ready for review November 5, 2025 17:36
@adil-a adil-a requested a review from a team as a code owner November 5, 2025 17:36
@adil-a adil-a changed the title feat: fp16 loss scaling for DTensor policies feat: fp16 for DTensor policies Nov 5, 2025

github-actions bot commented Nov 5, 2025

ℹ️ File Consistency Check

Check based on commit: 22608af (PR #1474 from hemil/fp16-loss-scaler)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Contributor

coderabbitai bot commented Nov 5, 2025

📝 Walkthrough

Gradient scaling support for FP16 training is added to two policy worker classes by introducing ShardedGradScaler. Conditional logic gates scaler usage in the backward pass, gradient unscaling, and the optimizer step based on the configured precision.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Gradient Scaling for FP16 Training<br>`nemo_rl/models/policy/dtensor_policy_worker.py`, `nemo_rl/models/policy/dtensor_policy_worker_v2.py` | Added ShardedGradScaler initialization when precision is float16; modified the backward pass to use scaler.scale(loss).backward() when a scaler exists; added gradient unscaling before clipping via scaler.unscale_() when a scaler is present; updated the optimizer step to use scaler.step() and scaler.update() when a scaler exists, otherwise falling back to the standard methods. |
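
The pattern described in the summary above can be illustrated with a minimal sketch (illustrative only; the helper names `make_scaler` and `train_step` are hypothetical and not the actual worker API):

```python
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def make_scaler(dtype: torch.dtype):
    # A scaler is only needed for FP16; BF16/FP32 train without loss scaling.
    return ShardedGradScaler() if dtype == torch.float16 else None


def train_step(model, optimizer, loss, scaler=None, max_grad_norm=1.0):
    # Backward pass: scale the loss only when a scaler is present (FP16).
    if scaler is not None:
        scaler.scale(loss).backward()
    else:
        loss.backward()

    # Unscale before clipping so the threshold applies to true gradient magnitudes.
    if scaler is not None:
        scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    # scaler.step() skips the update if gradients contain inf/nan;
    # update() then adjusts the loss scale for the next iteration.
    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad()
    return grad_norm
```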

Sequence Diagram(s)

sequenceDiagram
    actor Trainer
    participant Worker as Policy Worker
    participant Scaler as ShardedGradScaler
    participant Optimizer

    Trainer->>Worker: Forward pass (compute loss)
    
    alt FP16 precision enabled
        Worker->>Scaler: scale(loss)
        Scaler-->>Worker: scaled_loss
        Worker->>Worker: scaled_loss.backward()
    else Other precision
        Worker->>Worker: loss.backward()
    end
    
    alt Using scaler
        Worker->>Scaler: unscale_(optimizer)
        Worker->>Worker: clip_grad_norm_()
        Worker->>Scaler: step(optimizer)
        Scaler->>Optimizer: Apply updates
        Worker->>Scaler: update()
    else Standard path
        Worker->>Worker: clip_grad_norm_()
        Worker->>Optimizer: step()
    end
    
    Optimizer-->>Worker: Weights updated

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Consistent pattern applied homogeneously across two files reduces cognitive load
  • Conditional scaler logic is straightforward and non-intrusive
  • Pay close attention to ensure scaler initialization, unscaling, and stepping are correctly ordered in the training loop
  • Verify that fallback paths (when scaler is None) maintain existing behavior for non-FP16 training

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Test Results For Major Changes: ⚠️ Warning. The PR implements major changes to gradient handling and FP16 loss scaling in core training paths, but the PR description lacks any testing information, validation results, convergence metrics, or performance measurements. Resolution: add comprehensive testing information to the PR description, including convergence validation results, loss curve comparisons, performance metrics, and details of the testing environment and configurations used.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title 'feat: fp16 for DTensor policies' accurately reflects the main change: adding FP16 (float16) training support with gradient scaling to the DTensor policy workers.
  • Description Check: ✅ Passed. Skipped because CodeRabbit's high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch hemil/fp16-loss-scaler

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
nemo_rl/models/policy/dtensor_policy_worker.py (1)

1918-1948: Persist scaler state in checkpoints to avoid training instability on resume.

The scaler maintains internal state (current scale factor, growth tracker, etc.) that should be saved and restored with checkpoints. Without this, resuming training will reset the scaler to its initial state, potentially causing convergence issues or repeated scale adjustments.

Apply this diff to save and load scaler state:

 def save_checkpoint(
     self,
     weights_path: str,
     optimizer_path: Optional[str] = None,
     tokenizer_path: Optional[str] = None,
 ) -> None:
     """Save a checkpoint of the model.
 
     the optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
     """
     save_checkpoint(
         model=self.model,
         weights_path=weights_path,
         optimizer=self.optimizer if optimizer_path else None,
         scheduler=self.scheduler if optimizer_path else None,
+        scaler=self.scaler if optimizer_path else None,
         optimizer_path=optimizer_path,
         tokenizer=self.tokenizer if tokenizer_path else None,
         tokenizer_path=tokenizer_path,
     )

 def load_checkpoint(
     self, weights_path: str, optimizer_path: Optional[str] = None
 ) -> None:
     """Load a checkpoint into the model."""
     load_checkpoint(
         model=self.model,
         weights_path=weights_path,
         optimizer=self.optimizer if optimizer_path else None,
         scheduler=self.scheduler if optimizer_path else None,
+        scaler=self.scaler if optimizer_path else None,
         optimizer_path=optimizer_path,
     )

Note: You'll need to verify that the save_checkpoint and load_checkpoint utility functions support the scaler parameter. If not, update those functions accordingly.
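
For reference, the generic PyTorch pattern for persisting scaler state looks like the sketch below (the checkpoint dict layout and file name are illustrative, not the nemo_rl.utils checkpoint API):

```python
import torch

# Saving: persist the scaler's internal state (current scale, growth tracker, ...)
# alongside the optimizer state. The dict layout here is illustrative only.
state = {
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict() if scaler is not None else None,
}
torch.save(state, "optim_state.pt")

# Loading: restore the scaler state so loss scaling resumes where it left off
# instead of resetting to the initial scale.
state = torch.load("optim_state.pt")
optimizer.load_state_dict(state["optimizer"])
if scaler is not None and state.get("scaler") is not None:
    scaler.load_state_dict(state["scaler"])
```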

nemo_rl/models/policy/dtensor_policy_worker_v2.py (1)

1879-1933: Persist scaler state in checkpoints to avoid training instability on resume.

Same issue as in the non-v2 worker: scaler state must be saved and restored to maintain consistent loss scaling behavior across checkpoint resume.

Apply this diff:

 def save_checkpoint(
     self,
     weights_path: str,
     optimizer_path: Optional[str] = None,
     tokenizer_path: Optional[str] = None,
     checkpointing_cfg: Optional[CheckpointingConfig] = None,
 ) -> None:
     """Save a checkpoint of the model.
 
     the optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
     """
     if checkpointing_cfg is None:
         raise ValueError(
             "checkpointing_cfg must be provided when saving checkpoint"
         )
 
     # Extract only the checkpointing configuration keys that exist
     checkpoint_kwargs = {
         key: value
         for key, value in checkpointing_cfg.items()
         if key
         in {
             "model_save_format",
             "save_consolidated",
             "is_peft",
             "peft_config",
         }
     }
 
     save_checkpoint(
         model=self.model,
         weights_path=weights_path,
         optimizer=self.optimizer if optimizer_path else None,
         scheduler=self.scheduler if optimizer_path else None,
+        scaler=self.scaler if optimizer_path else None,
         optimizer_path=optimizer_path,
         tokenizer=self.tokenizer if tokenizer_path else None,
         tokenizer_path=tokenizer_path,
         model_state_dict_keys=self.model_state_dict_keys,
         **checkpoint_kwargs,
     )

 def load_checkpoint(
     self,
     weights_path: str,
     optimizer_path: Optional[str] = None,
 ) -> None:
     """Load a checkpoint into the model."""
     load_checkpoint(
         model=self.model,
         weights_path=weights_path,
         optimizer=self.optimizer if optimizer_path else None,
         scheduler=self.scheduler if optimizer_path else None,
+        scaler=self.scaler if optimizer_path else None,
         optimizer_path=optimizer_path,
     )

Note: Verify that nemo_rl.utils.automodel_checkpoint.save_checkpoint and load_checkpoint support the scaler parameter.

🧹 Nitpick comments (4)
nemo_rl/models/policy/dtensor_policy_worker.py (2)

202-206: Scaler initialization is correct but consider making growth_interval configurable.

The initialization logic properly gates scaler creation to FP16 precision only. However, growth_interval=400 is more aggressive than PyTorch's default (2000). While this may accelerate scale factor recovery, it could increase the risk of overflow if gradients are frequently unstable.

Consider making growth_interval a configuration parameter:

-        # Initialize gradient scaler for float16 training
-        if self.dtype == torch.float16:
-            self.scaler = ShardedGradScaler(growth_interval=400)
-        else:
-            self.scaler = None
+        # Initialize gradient scaler for float16 training
+        if self.dtype == torch.float16:
+            growth_interval = self.cfg.get("fp16_scaler_growth_interval", 400)
+            self.scaler = ShardedGradScaler(growth_interval=growth_interval)
+        else:
+            self.scaler = None
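
For context, a sketch of the scaler constructed with its knobs spelled out, assuming ShardedGradScaler mirrors torch.amp.GradScaler's constructor arguments (all values other than growth_interval are the PyTorch defaults):

```python
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler(
    init_scale=2.0**16,   # starting loss scale (PyTorch default)
    growth_factor=2.0,    # multiply the scale after `growth_interval` clean steps
    backoff_factor=0.5,   # halve the scale when inf/nan gradients are detected
    growth_interval=400,  # value used in this PR; the PyTorch default is 2000
)
```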

913-920: Consider adding scaler state to training metrics for observability.

Including the current loss scale factor in metrics would help monitor gradient scaling behavior and diagnose training issues related to FP16 precision.

Add scaler scale to metrics:

         metrics = {
             "global_loss": global_loss.cpu(),
             "grad_norm": grad_norm,
             "rank": torch.distributed.get_rank(),
             "gpu_name": torch.cuda.get_device_name(),
             "model_dtype": self.dtype,
+            "loss_scale": self.scaler.get_scale() if self.scaler is not None else None,
             "all_mb_metrics": dict(mb_metrics),
         }
nemo_rl/models/policy/dtensor_policy_worker_v2.py (2)

165-169: Scaler initialization is correct but consider making growth_interval configurable.

Same recommendation as the non-v2 worker: the hardcoded growth_interval=400 is more aggressive than the PyTorch default and should ideally be configurable via the policy config.

Apply the same refactor as suggested for dtensor_policy_worker.py:

-        # Initialize gradient scaler for float16 training
-        if self.dtype == torch.float16:
-            self.scaler = ShardedGradScaler(growth_interval=400)
-        else:
-            self.scaler = None
+        # Initialize gradient scaler for float16 training
+        if self.dtype == torch.float16:
+            growth_interval = self.cfg.get("fp16_scaler_growth_interval", 400)
+            self.scaler = ShardedGradScaler(growth_interval=growth_interval)
+        else:
+            self.scaler = None

888-895: Consider adding scaler state to training metrics for observability.

Same recommendation as the non-v2 worker: include the loss scale factor in metrics.

         metrics = {
             "global_loss": global_loss.cpu(),
             "grad_norm": grad_norm,
             "rank": torch.distributed.get_rank(),
             "gpu_name": torch.cuda.get_device_name(),
             "model_dtype": self.dtype,
+            "loss_scale": self.scaler.get_scale() if self.scaler is not None else None,
             "all_mb_metrics": dict(mb_metrics),
         }
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8762f57 and 22608af.

📒 Files selected for processing (2)
  • nemo_rl/models/policy/dtensor_policy_worker.py (5 hunks)
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (8)
nemo_rl/models/policy/dtensor_policy_worker.py (4)

36-36: LGTM! Correct import for FSDP-compatible gradient scaling.

The ShardedGradScaler is the appropriate choice for FSDP2-based models and will handle gradient scaling across distributed shards.


853-856: LGTM! Backward pass correctly applies gradient scaling.

The conditional logic properly applies scaler.scale(loss).backward() for FP16 training and falls back to standard loss.backward() otherwise. This is the correct pattern for mixed-precision training.


865-867: Critical: Gradient unscaling before clipping is correctly implemented.

Unscaling gradients before computing norms and clipping is essential—otherwise, the clipping threshold would be applied to scaled gradients, leading to incorrect gradient magnitudes. This implementation is correct.
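
To spell out the reasoning, a small sketch of the unscale-then-clip ordering with loss scale S (names are generic, not the worker's):

```python
# With scaled gradients (norm S * ||g||), clip_grad_norm_(params, max_norm)
# would trigger whenever ||g|| > max_norm / S, leaving true gradients with
# norm max_norm / S (clipped far harder than intended), and the returned
# grad_norm metric would be inflated by a factor of S.
scaler.unscale_(optimizer)  # restore true gradient magnitudes first
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
```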


884-888: LGTM! Optimizer step correctly uses scaler.

The scaler's step() method will skip the optimizer update if gradients contain inf/nan, and update() adjusts the scale factor for the next iteration. This is the correct usage pattern.
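
As a side note, a common heuristic (not something this PR does) for logging skipped steps is to compare the scale before and after update():

```python
# update() lowers the scale only when the preceding step() found inf/nan
# gradients, so a drop in scale indicates the optimizer step was skipped.
scale_before = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
step_was_skipped = scaler.get_scale() < scale_before
```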

nemo_rl/models/policy/dtensor_policy_worker_v2.py (4)

58-58: LGTM! Correct import for FSDP-compatible gradient scaling.

Consistent with the non-v2 worker implementation.


828-831: LGTM! Backward pass correctly applies gradient scaling.

Consistent with the non-v2 worker implementation.


840-842: Critical: Gradient unscaling before clipping is correctly implemented.

Consistent with the non-v2 worker—correctly unscales before norm computation and clipping.


859-863: LGTM! Optimizer step correctly uses scaler.

Consistent with the non-v2 worker implementation.

Contributor

@terrykong terrykong left a comment

@joyang-nv to review

@terrykong terrykong requested a review from joyang-nv November 5, 2025 17:47