Skip to content

feat(speechlm2): add Multi-Token Prediction (MTP) training support to SALMAutomodel#15791

Open
Slyne wants to merge 10 commits into
NVIDIA-NeMo:mainfrom
Slyne:slyne/mtp
Open

feat(speechlm2): add Multi-Token Prediction (MTP) training support to SALMAutomodel#15791
Slyne wants to merge 10 commits into
NVIDIA-NeMo:mainfrom
Slyne:slyne/mtp

Conversation

@Slyne

@Slyne Slyne commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Important

The Update branch button must only be pressed in very rare occassions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do ?

Add Multi-Token Prediction (MTP) auxiliary training support to SALMAutomodel, enabling the model to predict multiple future tokens per step using an auxiliary head attached to a NemotronV3 base checkpoint that ships without one.

Collection: speechlm2

Changelog

  • Add _mtp_enabled property to SALMAutomodel for a safe, pre-configure_model-safe check of whether the MTP head is attached
  • Add _build_and_attach_mtp_head() to construct the NemotronV3 MTP head via Automodel factories and attach it as self.llm.mtp after the LLM is loaded (the base checkpoint has no MTP head and the normal from_pretrained injection path does not work for this architecture)
  • FSDP2-shard the MTP head over the DP mesh in configure_model, since it is attached after the LLM's own FSDP wrapping
  • Compute MTP auxiliary loss in the training step using calculate_mtp_loss from nemo_automodel, normalized by the same - - global token count as the main loss and scaled by dp_size to cancel FSDP gradient averaging
  • Guard MTP against the packed-sequences (THD) path, which is not yet supported
  • MTP is opt-in via mtp.enabled: true in the model config; relevant fields: loss_scaling_factor, num_nextn_predict_layers, hybrid_override_pattern, use_repeated_layer

Usage

Add the below to the config file to enable MTP heads

# In your SALM Hydra config (e.g. conf/salm_automodel.yaml):
model:
  mtp:
    enabled: true
    num_nextn_predict_layers: 1
    loss_scaling_factor: 0.1
    hybrid_override_pattern: "*"
    use_repeated_layer: false

The mtp_loss scalar is logged on every training step (prog_bar=True) so you can monitor it in your experiment tracker alongside the main loss.

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed c
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • [✗] Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

This PR depends on the changes in the below PR:

SlyneD and others added 10 commits June 11, 2026 12:03
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
…hether the self.training is a must for MTP

Signed-off-by: SlyneD <slyned@nvidia.com>
Signed-off-by: SlyneD <slyned@nvidia.com>
salm_eval.py and vllm/salm/backends.py changes will be submitted on a
dedicated evaluation branch, not this MTP training PR.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace inline getattr(self.llm, "mtp", None) checks with self._mtp_enabled
- Remove redundant scaling fallback; use self._mtp_loss_scaling_factor directly
- Drop mtp_loss_scaling_factor passthrough from forward pass (unused)
- Remove redundant num_nextn_predict_layers kwarg to build_mtp_config_from_hf
- Extract mtp_requested local bool to avoid repeating the enabled check

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Cover _mtp_enabled property (4 cases: no llm, no attr, None, attached)
and forward() pass-through of mtp_per_depth_h (present / absent).
All 6 tests pass in the nemo-automodel:26.04 container.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 11, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

# (per-rank-different) data. The head is attention-only (no experts), so plain
# FSDP2 over the DP mesh is the right wrapping.
if self._mtp_enabled:
self.llm.mtp = fully_shard(self.llm.mtp, mesh=fsdp_mesh)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should guard this with an extra condition in case we finetune a model that does have MTP already in base checkpoint (e.g. Nemotron Super)

@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond waiting-on-maintainers Waiting on maintainers to respond and removed waiting-on-customer Waiting on the original author to respond waiting-on-maintainers Waiting on maintainers to respond labels Jun 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants