feat(speechlm2): add Multi-Token Prediction (MTP) training support to SALMAutomodel#15791
Open
Slyne wants to merge 10 commits into
Open
feat(speechlm2): add Multi-Token Prediction (MTP) training support to SALMAutomodel#15791Slyne wants to merge 10 commits into
Slyne wants to merge 10 commits into
Conversation
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
…hether the self.training is a must for MTP Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
salm_eval.py and vllm/salm/backends.py changes will be submitted on a dedicated evaluation branch, not this MTP training PR. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace inline getattr(self.llm, "mtp", None) checks with self._mtp_enabled - Remove redundant scaling fallback; use self._mtp_loss_scaling_factor directly - Drop mtp_loss_scaling_factor passthrough from forward pass (unused) - Remove redundant num_nextn_predict_layers kwarg to build_mtp_config_from_hf - Extract mtp_requested local bool to avoid repeating the enabled check Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Cover _mtp_enabled property (4 cases: no llm, no attr, None, attached) and forward() pass-through of mtp_per_depth_h (present / absent). All 6 tests pass in the nemo-automodel:26.04 container. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
pzelasko
reviewed
Jun 12, 2026
| # (per-rank-different) data. The head is attention-only (no experts), so plain | ||
| # FSDP2 over the DP mesh is the right wrapping. | ||
| if self._mtp_enabled: | ||
| self.llm.mtp = fully_shard(self.llm.mtp, mesh=fsdp_mesh) |
Collaborator
There was a problem hiding this comment.
we should guard this with an extra condition in case we finetune a model that does have MTP already in base checkpoint (e.g. Nemotron Super)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Important
The
Update branchbutton must only be pressed in very rare occassions.An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do ?
Add Multi-Token Prediction (MTP) auxiliary training support to
SALMAutomodel, enabling the model to predict multiple future tokens per step using an auxiliary head attached to a NemotronV3 base checkpoint that ships without one.Collection: speechlm2
Changelog
_mtp_enabledproperty toSALMAutomodelfor a safe, pre-configure_model-safe check of whether the MTP head is attached_build_and_attach_mtp_head()to construct the NemotronV3 MTP head via Automodel factories and attach it asself.llm.mtpafter the LLM is loaded (the base checkpoint has no MTP head and the normalfrom_pretrainedinjection path does not work for this architecture)configure_model, since it is attached after the LLM's own FSDP wrappingdp_sizeto cancel FSDP gradient averagingmtp.enabled: truein the model config; relevant fields:loss_scaling_factor,num_nextn_predict_layers,hybrid_override_pattern,use_repeated_layerUsage
Add the below to the config file to enable MTP heads
The
mtp_lossscalar is logged on every training step (prog_bar=True) so you can monitor it in your experiment tracker alongside the main loss.GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information
This PR depends on the changes in the below PR: