DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export #1621
yeyu-nvidia wants to merge 6 commits into
Models without a <|mask|> token (e.g., MiniMax-M2.7) would fail with ValueError during DFlash training. Instead of requiring the user to manually set dflash_mask_token_id, add the token to the tokenizer and resize model embeddings automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
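The behaviour described in this commit message could look roughly like the sketch below. It is not the PR's code: the helper name `ensure_mask_token` is hypothetical, the objects are assumed to follow the Hugging Face tokenizer/model interface (`get_vocab`, `add_tokens`, `resize_token_embeddings`), and the check that skips resizing when the embedding table is already large enough is an added assumption.

```python
def ensure_mask_token(tokenizer, model, mask_token="<|mask|>"):
    """Return the id of `mask_token`, adding it to the tokenizer if it is missing.

    Hypothetical helper. Assumes Hugging Face-style objects: the tokenizer
    provides get_vocab / add_tokens / convert_tokens_to_ids / __len__, and the
    model provides get_input_embeddings / resize_token_embeddings.
    """
    if mask_token not in tokenizer.get_vocab():
        # Register the token as a special token so it is never split.
        tokenizer.add_tokens([mask_token], special_tokens=True)

    # Assumption: only grow the embedding table when the vocabulary outgrew it,
    # since some checkpoints pad the table beyond the tokenizer's size.
    rows = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > rows:
        model.resize_token_embeddings(len(tokenizer))

    return tokenizer.convert_tokens_to_ids(mask_token)
```

The returned id would then play the role of `dflash_mask_token_id`, so the user no longer has to set it by hand.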
When slurm_config.requeue is True, set additional_parameters["requeue"] = True so nemo-run emits #SBATCH --requeue in the sbatch script. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1. main.py: When FSDP2 cpu_ram_efficient_loading is active, only rank 0 loads real weights on CPU; other ranks use meta device. FSDP2 distributes from rank 0. Also adds dp_replicate_size auto-computation so dp_replicate * dp_shard * cp == world_size. 2. core.py: Set retries=3 when requeue is requested. The nemo-run sbatch wrapper only calls scontrol requeue when TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT — retries=0 (the default) disabled requeue. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
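The `dp_replicate_size` auto-computation mentioned in item 1 reduces to simple arithmetic. A minimal sketch, assuming the function name `infer_dp_replicate` (hypothetical) and that the other two factors are already known:

```python
def infer_dp_replicate(world_size: int, dp_shard: int, cp: int = 1) -> int:
    """Return dp_replicate such that dp_replicate * dp_shard * cp == world_size.

    Hypothetical helper illustrating the constraint from the commit message.
    """
    if world_size <= 0 or dp_shard <= 0 or cp <= 0:
        raise ValueError("world_size, dp_shard and cp must be positive")
    group = dp_shard * cp
    if world_size % group != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by dp_shard * cp = {group}"
        )
    return world_size // group
```

Raising on a non-divisible world size is a design choice of this sketch; the PR may handle that case differently.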
…heckpoint resume The Pydantic recipe refactor dropped fsdp2_buffer_patch.apply() and patch_accelerator() calls and added a buffer-to-CUDA block that moved DFlash buffers before FSDP wrapping. With cpu_ram_efficient_loading, non-rank-0 processes have meta-device params, causing _infer_parameter_dtype() to return fp32 instead of bf16 on resume. Also detects FSDP distributed checkpoints (no HF model files) and loads the base model instead of trying from_pretrained on them. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
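The checkpoint-format detection described here can be sketched with the standard library alone. The function names below are hypothetical, and the list of Hugging Face weight file names is an assumption about what counts as "HF model files"; the `pytorch_model_fsdp_*` directory pattern follows the `pytorch_model_fsdp_0/` name quoted elsewhere in this PR.

```python
from pathlib import Path

# Assumed set of files that mark a directory as loadable by from_pretrained.
HF_WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
)


def has_hf_weights(ckpt_dir) -> bool:
    """True if the directory contains any of the assumed HF weight files."""
    d = Path(ckpt_dir)
    return any((d / name).is_file() for name in HF_WEIGHT_FILES)


def is_fsdp_sharded_checkpoint(ckpt_dir) -> bool:
    """True for a directory with FSDP shard folders and no HF weight files."""
    d = Path(ckpt_dir)
    if not d.is_dir() or has_hf_weights(d):
        return False
    return any(p.is_dir() for p in d.glob("pytorch_model_fsdp_*"))


def resolve_model_source(ckpt_dir, base_model_path: str) -> str:
    """Pick what to pass to from_pretrained: the base model for FSDP checkpoints."""
    if is_fsdp_sharded_checkpoint(ckpt_dir):
        return base_model_path
    return str(ckpt_dir)
```

With this split, the trainer's own resume logic would still restore the sharded weights; the helper only decides where the model skeleton is loaded from.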
… patch _infer_parameter_dtype() reads the model's current param dtype to cast the broadcasted tensor. With cpu_ram_efficient_loading, non-rank-0 processes have fp32 meta-device params for DFlash, so _infer_parameter_dtype returns fp32 and _finish() casts the correctly- broadcasted bf16 tensor back to fp32. Use bcast_dtype (from rank 0) instead. Also prints dtype_check on all ranks to verify consistency. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The Pydantic-recipe refactor (7038dec) dropped DFlashExportCallback, which had exported the DFlash draft submodule after every checkpoint save. Without it, FSDP2 sharded checkpoints (pytorch_model_fsdp_0/, no model.safetensors) get no exported-checkpoint-{step}/, so downstream vLLM deployment / acceptance-length eval has nothing to load. The verify-only comment 'export happens during training via DFlashExportCallback' was left behind but the callback itself was gone. Restore the callback (gathers only the ~328MB draft submodule across shards via get_model_state_dict, so it works under SHARDED_STATE_DICT without materializing the full base model) and wire it into main.py for DFlash recipes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
📝 Walkthrough
This PR introduces distributed-training infrastructure for speculative decoding: a DFlash draft-weight export callback, FSDP2 buffer-patch utilities for synchronized distributed loading and DTensor-safe gradient clipping, main script integration for both features, improved checkpoint format detection with fallback loading, automatic DFlash mask-token initialization, and Slurm requeue configuration support in the launcher.
Changes
Speculative Decoding Training Enhancements
Launcher Slurm Requeue Configuration
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/eagle_utils.py (1)
53-53: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Update `__all__` to include `DFlashExportCallback`.
The coding guidelines require defining the public API with `__all__`. Since `DFlashExportCallback` is imported by `main.py` (line 39), it should be exported.
    -__all__ = ["EagleOfflineDataCollator", "OfflineSupervisedDataset"]
    +__all__ = ["DFlashExportCallback", "EagleOfflineDataCollator", "OfflineSupervisedDataset"]
As per coding guidelines: "Define the public API with `__all__` at the top of each Python module."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/eagle_utils.py` at line 53, The module's public API list __all__ is missing DFlashExportCallback; update the __all__ declaration (which currently lists "EagleOfflineDataCollator" and "OfflineSupervisedDataset") to also include "DFlashExportCallback" so the symbol is exported for consumers like main.py that import it.
🧹 Nitpick comments (5)
examples/speculative_decoding/fsdp2_buffer_patch.py (4)
238-240: ⚡ Quick win
Use `print_rank_0` to avoid noisy logs in multi-rank environments.
These print statements execute on every rank, which can produce excessive output on large clusters. Consider using `print_rank_0` from `modelopt.torch.utils` or guarding with a rank check.
    +from modelopt.torch.utils import print_rank_0
    +
     # In apply() function:
    -        print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
    +        print_rank_0("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
         except Exception as e:
    -        print(f"[fsdp2_buffer_patch] Patch skipped: {e}")
    +        print_rank_0(f"[fsdp2_buffer_patch] Patch skipped: {e}")
As per coding guidelines: "use `print_rank_0` or `warn_rank_0` to avoid noisy logs."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 238 - 240, The two print calls in fsdp2_buffer_patch (the success message and the except message around fsdp2_load_full_state_dict) should be replaced with rank-safe logging: import and call print_rank_0 from modelopt.torch.utils (or guard with a rank check) so messages only appear on rank 0; update the success print and the exception print to use print_rank_0 and include the exception variable in the error message (e) while keeping the same text context.
1-3: 💤 Low value
Add `__all__` to define the public API.
Per coding guidelines, each Python module should define `__all__` to make the public API explicit.
    +__all__ = ["apply", "patch_accelerator"]
    +
     """Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling.
As per coding guidelines: "Define the public API with `__all__` at the top of each Python module."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 1 - 3, This module is missing an explicit public API; add a module-level __all__ declaration at the top (after the SPDX headers) listing the public names exported by this file (e.g. __all__ = ["Name1", "function_name", "CLASS_NAME"]), ensuring each symbol included matches the actual top-level functions/classes/variables defined later in the file; place the __all__ immediately below the license lines to satisfy the coding guideline.
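One detail worth checking when applying this suggestion: Python only treats a string literal as the module docstring when it is the first statement of the module, so an `__all__` assignment placed above it would leave the module without a docstring. The sketch below demonstrates this with `ast.get_docstring`; the two source strings are illustrative stand-ins, not the actual file contents.

```python
import ast

# Docstring first, __all__ immediately after: the docstring is preserved.
DOCSTRING_FIRST = '''"""Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling."""

__all__ = ["apply", "patch_accelerator"]
'''

# __all__ first: the string below it is just an expression statement.
ALL_FIRST = '''__all__ = ["apply", "patch_accelerator"]

"""Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling."""
'''


def module_docstring(source: str):
    """Return the module docstring of `source`, or None if it has none."""
    return ast.get_docstring(ast.parse(source))
```

Placing `__all__` directly below the license header and docstring keeps both the guideline's intent and the module documentation.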
322-326: ⚡ Quick win
Use `print_rank_0` here as well.
     def patch_accelerator(accelerator):
         """Replace accelerator's clip_grad_norm_ with FSDP2-safe version."""
         accelerator.clip_grad_norm_ = _clip_grad_norm
    -    print("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ "
    -          "for FSDP2 DTensor compatibility")
    +    from modelopt.torch.utils import print_rank_0
    +    print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ "
    +                 "for FSDP2 DTensor compatibility")
As per coding guidelines: "use `print_rank_0` or `warn_rank_0` to avoid noisy logs."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 322 - 326, The patch_accelerator function currently uses print to log the patch; replace that call with print_rank_0 to follow logging guidelines. Update the function to call print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility") and ensure print_rank_0 is imported at top of the module (the same utility used elsewhere), leaving accelerator.clip_grad_norm_ = _clip_grad_norm unchanged.
269-270: 💤 Low value
Return value should be on the same device as gradients.
When there are no gradients, the function returns a CPU tensor. For consistency with the non-empty case (which returns `total_norm` on device), consider returning on the same device.
     if len(grads) == 0:
    -    return torch.tensor(0.0)
    +    device = parameters[0].device if parameters else "cpu"
    +    return torch.tensor(0.0, device=device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 269 - 270, The early-return creates a CPU tensor when grads is empty; change it to return a zero tensor on the same device as the gradients by selecting the first available grad device (e.g. device = next((g.device for g in grads if g is not None), torch.device('cpu'))) and then return torch.tensor(0.0, device=device) so the empty-case matches the device of the non-empty case that returns total_norm.
examples/speculative_decoding/main.py (1)
297-303: ⚡ Quick win
Consider gating debug output or removing before merge.
This debug print executes on every rank and will produce verbose output on large clusters. If this is temporary debugging code, consider removing it or guarding with a debug flag.
    -    rank = int(os.environ.get("RANK", 0))
    -    dtypes = {}
    -    for name, p in trainer.model.named_parameters():
    -        dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
    -        dtypes.setdefault(dt_key, []).append(name)
    -    for dt, names in dtypes.items():
    -        print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")
    +    if os.environ.get("DEBUG_DTYPES"):
    +        rank = int(os.environ.get("RANK", 0))
    +        dtypes = {}
    +        for name, p in trainer.model.named_parameters():
    +            dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
    +            dtypes.setdefault(dt_key, []).append(name)
    +        for dt, names in dtypes.items():
    +            print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")
As per coding guidelines: "use `print_rank_0` or `warn_rank_0` to avoid noisy logs."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/main.py` around lines 297 - 303, The debug loop printing per-rank dtype info (uses rank, dtypes, iterating trainer.model.named_parameters()) should be gated or replaced to avoid noisy logs: either remove the prints or wrap them so only rank 0 logs (use existing print_rank_0 or warn_rank_0 utility) and/or guard with a debug flag (e.g., if DEBUG:). Update the block that builds dtypes and the final print to call print_rank_0 (or warn_rank_0) with the formatted message so only the main process emits the output, or conditionally execute the entire loop behind a debug configuration toggle.
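Because the commit message says the dtype check is printed on all ranks to verify consistency, an opt-in flag may fit better than rank-0-only logging. A minimal, torch-free sketch of that gating follows; it takes `(name, dtype)` pairs instead of a real model, reuses the `DEBUG_DTYPES` variable name from the suggestion above, and the function names are hypothetical.

```python
import os
from collections import defaultdict


def summarize_dtypes(named_dtypes):
    """Group parameter names by dtype string; return one summary line per dtype."""
    groups = defaultdict(list)
    for name, dtype in named_dtypes:
        groups[str(dtype)].append(name)
    return [
        f"{dtype}: {len(names)} params (e.g. {names[0]})"
        for dtype, names in groups.items()
    ]


def report_dtypes(named_dtypes, flag="DEBUG_DTYPES"):
    """Print the per-rank dtype summary only when the debug flag is set."""
    if not os.environ.get(flag):
        return []
    rank = int(os.environ.get("RANK", "0"))
    lines = summarize_dtypes(named_dtypes)
    for line in lines:
        # Every rank prints when enabled, so ranks can be compared by eye.
        print(f"[dtype_check rank={rank}] {line}")
    return lines
```

In the training script the pairs would come from something like iterating `trainer.model.named_parameters()` and reading each parameter's dtype.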
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ec55dcce-a920-44ca-8e39-ee3167ca3eeb
📒 Files selected for processing (4)
examples/speculative_decoding/eagle_utils.py
examples/speculative_decoding/fsdp2_buffer_patch.py
examples/speculative_decoding/main.py
tools/launcher/core.py
        additional_parameters=getattr(slurm_config, "additional_parameters", None) or {},
    )
    if getattr(slurm_config, "requeue", False):
        executor.additional_parameters["requeue"] = True
        # The nemo-run sbatch wrapper only calls `scontrol requeue` when
        # TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default)
        # disables this, so bump it when requeue is requested.
        executor.retries = max(executor.retries, 3)
Harden additional_parameters normalization before mutation.
At Line 280 and Line 283, this assumes additional_parameters is always a mutable mapping. Since these attrs are externally supplied, normalize/validate to dict before assignment and mutate a local copy to avoid shared-state side effects.
Proposed fix
- executor = run.SlurmExecutor(
+ raw_additional_parameters = getattr(slurm_config, "additional_parameters", None)
+ additional_parameters = {}
+ if raw_additional_parameters is not None:
+ if not isinstance(raw_additional_parameters, dict):
+ raise TypeError("slurm_config.additional_parameters must be a dict")
+ additional_parameters = dict(raw_additional_parameters)
+
+ executor = run.SlurmExecutor(
account=slurm_config.account,
partition=slurm_config.partition,
qos=slurm_config.qos,
@@
- additional_parameters=getattr(slurm_config, "additional_parameters", None) or {},
+ additional_parameters=additional_parameters,
 )
As per coding guidelines: "Validate external input once at the interface boundary; internal code can trust those checks and avoid redundant assertions".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tools/launcher/core.py` around lines 280 - 287, The code assumes
slurm_config.additional_parameters is a mutable dict and mutates it directly,
which can cause shared-state bugs; before assigning to
executor.additional_parameters (and before mutating it to set "requeue"),
validate and normalize slurm_config.additional_parameters to a plain dict (e.g.,
treat None, mappings, or other types safely), create a shallow copy for
executor.additional_parameters, and then mutate that copy; also ensure
executor.retries is updated via executor.retries = max(executor.retries, 3) as
shown. Reference: slurm_config.additional_parameters,
executor.additional_parameters, and executor.retries.
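The requeue condition quoted in the code comment of this thread is easy to get wrong, so a small sketch may help. Both function names are hypothetical; the comparison mirrors the stated rule (`TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT`) and the minimum of 3 mirrors `max(executor.retries, 3)`.

```python
def should_requeue(max_retries: int, restart_count: int) -> bool:
    """Requeue only while the retry budget exceeds the restarts already used."""
    return max_retries > restart_count


def effective_retries(current_retries: int, requeue: bool, minimum: int = 3) -> int:
    """Raise retries to at least `minimum` when requeue is requested."""
    return max(current_retries, minimum) if requeue else current_retries
```

With the default `retries=0`, `should_requeue(0, 0)` is already false on the first failure, which is why the commit bumps the value whenever requeue is enabled; a user-provided higher value is left alone.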
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@            Coverage Diff             @@
## main #1621 +/- ##
==========================================
+ Coverage 76.55% 77.10% +0.55%
==========================================
Files 478 478
Lines 52035 52035
==========================================
+ Hits 39834 40121 +287
+ Misses 12201 11914 -287
What
Brings up DFlash block-diffusion speculative decoding for large MoE targets (MiniMax-M2.7, 229B) trained under accelerate FSDP2, and fixes the regressions that broke checkpoint resume and per-checkpoint draft export.
Commits
- `build_slurm_executor` + FSDP2 `cpu_ram_efficient_loading` for 229B on multi-node.
- `fsdp2_buffer_patch.py`: handle non-DTensor buffers in `fsdp2_load_full_state_dict`, broadcast dtype codes from rank 0, and an FSDP2-safe `clip_grad_norm_`. Required because MiniMax-M2.7 pins transformers 4.57.x (no native `ParallelismConfig`).
- `DFlashExportCallback` (this PR's headline): the Pydantic-recipe refactor (7038dec) dropped the callback that exported the draft submodule after each checkpoint save, leaving a stale "export happens during training via DFlashExportCallback" comment with no callback. FSDP2 SHARDED_STATE_DICT checkpoints carry no `model.safetensors`, so without it there is nothing for vLLM / acceptance-length eval to load. The callback gathers only the ~328 MB draft submodule across shards via `get_model_state_dict(..., submodules={dflash_module}, full_state_dict=True, cpu_offload=True)`, which works under SHARDED_STATE_DICT without materializing the 229B base, and writes `exported-checkpoint-{step}/`.
Testing
🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes