diff --git a/docs/source/dataloaders.rst b/docs/source/dataloaders.rst index 8a7ed848b8a8..20fd0f2f0b90 100644 --- a/docs/source/dataloaders.rst +++ b/docs/source/dataloaders.rst @@ -685,3 +685,61 @@ Other, more exotic configurations: * With ``seed="trng"``, the base random seed itself will be drawn using a TRNG. It will be different on each GPU training process. This setting is not recommended. * With ``seed="randomized"``, the base random seed is set to Python's global RNG seed. It might be different on each GPU training process. This setting is not recommended. + +CP/TP-safe batches with ``BroadcastingDataLoader`` +--------------------------------------------------- + +Context-parallel (CP) and tensor-parallel (TP) training require all ranks +within the same ``(cp, tp)`` sub-mesh of a DP slot to process the **same** +global batch each step — CP shards the sequence dimension and TP shards +the feature dimension, so a divergent global batch breaks the per-rank +shape contract that CP/TP collectives assume. + +Independent Lhotse loaders on each rank with ``shard_seed="randomized"`` +guarantee that *seeded* shard cursors line up, but they don't protect +against background-thread non-determinism (``concurrent_bucketing``, +worker scheduling jitter, etc.). The empirical signature is per-rank +``cu_seqlens`` divergence at a fraction of training steps, which then +deadlocks NCCL collectives with mismatched shapes. + +The :class:`~nemo.collections.common.data.lhotse.broadcasting.BroadcastingDataLoader` +fixes this at the data layer: construct the real Lhotse loader on a +single DP-source rank (``cp_rank == 0`` and ``tp_rank == 0``) and let the +wrapper broadcast each batch to the other ranks in the ``(cp, tp)`` +sub-mesh over NCCL. Iteration ends in lockstep via a continue/stop +broadcast — no length needs to be known up-front. + +.. code-block:: python + + from torch.distributed.device_mesh import init_device_mesh + + from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config + from nemo.collections.common.data.lhotse.broadcasting import ( + BroadcastingDataLoader, + is_dp_source_rank, + ) + + mesh = init_device_mesh("cuda", (dp, cp, tp), mesh_dim_names=("dp", "cp", "tp")) + + if is_dp_source_rank(mesh): + source = get_lhotse_dataloader_from_config( + config=cfg.train_ds, + global_rank=dp_rank, + world_size=dp_size, + dataset=dataset, + tokenizer=tokenizer, + ) + else: + source = None + + return BroadcastingDataLoader(source=source, device_mesh=mesh) + +The wrapper delegates ``state_dict`` / ``load_state_dict`` to the source +loader on the source rank (no-ops on non-source ranks), so checkpoint and +resume keep working transparently with regular ``DataLoader``, +``torchdata.StatefulDataLoader``, or any other source object that +implements those methods. + +The wrapper is a no-op when ``device_mesh`` is ``None`` or every named +axis present in the mesh has size 1, so the same call site works for +single-GPU, DDP-only, and CP/TP runs without a separate code path. diff --git a/docs/source/speechlm2/configs.rst b/docs/source/speechlm2/configs.rst index eeadb378dccd..dec918d805b9 100644 --- a/docs/source/speechlm2/configs.rst +++ b/docs/source/speechlm2/configs.rst @@ -229,6 +229,32 @@ Defaults come from Automodel's ``BackendConfig`` and auto-select TransformerEngi DeepEP when available; override here to pin a specific backend (for example, ``attn: sdpa`` to bypass TE). +**Packed sequences (THD):** + +.. 
code-block:: yaml + + model: + packed_sequences: true # default false (right-padded BSHD path) + automodel_backend: + attn: te # THD path dispatches TE varlen FlashAttention + +When ``packed_sequences`` is true, ``SALMAutomodel.prepare_inputs`` packs +each minibatch into a single flat ``[T_total, H]`` sequence with a +``cu_seqlens`` index instead of right-padding to ``[B, T_max, H]``. +``SALMAutomodel`` then forwards the THD metadata (``qkv_format``, +``cu_seqlens``, ``position_ids``, ``max_seqlen``) through ``forward()`` to +the LLM. The TE attention preprocessor splits the singular ``max_seqlen`` +into the ``max_seqlen_q`` / ``max_seqlen_kv`` pair that +``DotProductAttention`` requires for ``qkv_format="thd"``. The packing also +rounds each utterance's flat length up to a multiple of ``2 * cp_size`` so +the same THD batch satisfies TE's CP DualChunkSwap contract — see the +"Context Parallelism (CP)" subsection in +:doc:`training_and_scaling` for the recommended pairing with ``cp_size > 1``. + +Padding overhead drops from ``O(B * (T_max - T_avg))`` to +``O(per-utt rounding to 2*cp_size)``. Throughput improvement scales with +the variance of utterance lengths in your bucketing. + DuplexS2SModel Configuration ----------------------------- diff --git a/docs/source/speechlm2/training_and_scaling.rst b/docs/source/speechlm2/training_and_scaling.rst index e1f5ec56ae56..4e213319538a 100644 --- a/docs/source/speechlm2/training_and_scaling.rst +++ b/docs/source/speechlm2/training_and_scaling.rst @@ -183,8 +183,93 @@ For distributed inference, launch with ``torchrun``: inputs=path/to/manifest \ ep_size=2 -Configuration -^^^^^^^^^^^^^ +Packed Sequences (THD) +"""""""""""""""""""""" + +``SALMAutomodel`` supports an opt-in packed-sequence (``THD``) training and +validation path that concatenates per-utterance text + audio embeddings into +a single flat ``[T_total, H]`` sequence with a ``cu_seqlens`` index, instead +of right-padding into the standard ``[B, T_max, H]`` (``BSHD``) layout. TE's +varlen FlashAttention then operates segment-by-segment without ever attending +across utterances, and Mamba's ``seq_idx`` is derived from the same +``cu_seqlens`` so SSM state resets at document boundaries. + +For variable-length speech batches the padding overhead is substantial — the +``BSHD`` layout pays ``B * (T_max - T_avg)`` wasted compute per minibatch, +``THD`` pays only the per-utterance rounding to a multiple of ``2*cp_size`` +(needed for TE's CP DualChunkSwap pattern). Throughput improvement scales +with the variance of utterance lengths. + +Enable per-batch: + +.. code-block:: yaml + + model: + packed_sequences: true # opt-in; default false (BSHD) + automodel_backend: + attn: te # THD path requires TE attention + +When ``packed_sequences`` is unset, the existing BSHD path is used unchanged. +Generate / inference always uses BSHD (it doesn't go through ``prepare_inputs``). + +Context Parallelism (CP) +"""""""""""""""""""""""" + +``SALMAutomodel`` supports context parallelism for long-audio training on +hybrid Mamba/attention LLMs (e.g. Nemotron-V3). CP shards the sequence +dimension across GPUs so per-rank activations and KV-cache memory scale as +``T / cp_size`` instead of ``T``; attention layers go through TE's +DualChunkSwap pattern and Mamba mixers go through hidden-parallel +all-to-all (``MambaContextParallel`` in NeMo Automodel). + +Enable via the strategy: + +.. 
code-block:: yaml + + trainer: + strategy: + _target_: nemo.collections.speechlm2.parts.parallel.AutomodelParallelStrategy + cp_size: 2 # context parallel size; must divide num_heads of every Mamba block + ep_size: 2 # may share the same ranks as CP + +**The THD packed-sequence path is the only supported configuration under +CP.** Each utterance is its own attention segment and the per-utterance +sequence rounding aligns naturally with CP's ``2*cp_size`` requirement. + +.. warning:: + **BSHD + CP is not supported.** TE's fused-attention CP path supports + ``causal`` but not ``padding_causal``, so the right-pad mask must be + dropped before the LLM. With the mask dropped, pad K/V leak into + real-token attention through the causal mask and the gradient through + the LoRA / projection parameters becomes ``NaN`` after the first + optimizer step (validated empirically: BSHD + CP=2 + EP=2 on a 2-GPU + run produces ``loss=4.62`` at step 1 then ``loss=nan`` from step 2 + onwards). This is independent of the TE/cuDNN backward issue + documented below — setting ``NVTE_FUSED_ATTN=0`` does not fix it. + Set ``model.packed_sequences: true`` to use the THD path instead. + +.. note:: + **CP-safe data loading is automatic.** The speechlm2 datamodule wraps + the Lhotse loader in + :class:`~nemo.collections.common.data.lhotse.broadcasting.BroadcastingDataLoader`, + so under CP/TP every batch is constructed once on the DP source rank + (``cp_rank == 0`` and ``tp_rank == 0``) and broadcast to its sub-mesh + peers. This eliminates per-rank Lhotse non-determinism (``concurrent_bucketing``, + worker scheduling jitter, etc.) as a source of NCCL deadlocks under CP. + See :doc:`/dataloaders` for the standalone API. + +.. note:: + **TE/THD exploding-gradients workaround on some GPUs.** On certain GPU + architectures (notably Blackwell ``sm_120``), the cuDNN backend that + TransformerEngine 2.14 picks for ``qkv_format="thd"`` with + ``attn_mask_type="padding_causal"`` returns correct forward activations + but gradients amplified 8×–960× per layer. Compounded across the LLM's + attention stack this drives gradients to ``1e22``-magnitudes at step 0, + the gradient-clip-by-norm computes ``1.0 / inf = 0``, and Adam's moments + eventually NaN. Force TE to dispatch FlashAttention instead of cuDNN by + setting ``NVTE_FUSED_ATTN=0`` in the launcher environment (requires + ``flash-attn`` to be installed for your GPU arch). The FlashAttention + THD/``padding_causal`` backward is gradient-correct on the same shapes. To configure parallelism, modify the ``trainer.strategy`` section in your YAML config: diff --git a/nemo/collections/common/data/lhotse/broadcasting.py b/nemo/collections/common/data/lhotse/broadcasting.py new file mode 100644 index 000000000000..ac4e6b7354ee --- /dev/null +++ b/nemo/collections/common/data/lhotse/broadcasting.py @@ -0,0 +1,234 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CP/TP-aware data loading. 
+ +Under context-parallel (CP) and tensor-parallel (TP) training, all ranks in +the same ``(cp, tp)`` sub-mesh of a DP slot must process the **same** global +batch each step — CP shards the sequence dimension and TP shards the +feature dimension, so a divergent global batch breaks the per-rank shape +contract that CP/TP collectives assume. + +The fix: construct the dataloader on a single DP-source rank per slot and +broadcast each batch over NCCL to the other ranks in the ``(cp, tp)`` +sub-mesh, eliminating the entire class of nondeterminism bug regardless of +source (Lhotse ``concurrent_bucketing``, ``shard_seed: randomized``, worker +scheduling jitter, etc.). + +:class:`BroadcastingDataLoader` is the single-class API: + + # In the datamodule: + return BroadcastingDataLoader( + source=real_loader if is_dp_source_rank(mesh) else None, + device_mesh=mesh, + ) + +The wrapper hides the broadcast bookkeeping. ``state_dict`` / +``load_state_dict`` are delegated to the source loader on the source rank, +so checkpoint/resume works transparently with ``DataLoader``, +``torchdata.StatefulDataLoader``, or any other source object that +implements those methods. + +Iteration termination is handled with two broadcasts per step: a +continue/stop boolean followed by the batch. This works regardless of +whether the source loader exposes ``__len__`` (Lhotse training loaders +typically don't). +""" +from __future__ import annotations + +from typing import Any, Iterable, Iterator, Sequence + +import torch +import torch.distributed as dist + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def is_dp_source_rank( + device_mesh, + axes: tuple[str, ...] = ("cp", "tp"), +) -> bool: + """True iff this rank is the data-parallel source for its DP slot. + + A DP source rank has coordinate 0 along every named axis (e.g. ``cp_rank == 0`` + and ``tp_rank == 0``). Pass the real dataloader to + :class:`BroadcastingDataLoader` only on DP source ranks; pass ``None`` + on the others. + + Returns True unconditionally when ``device_mesh`` is None or every named + axis present in the mesh has size 1, so callers can short-circuit setup + logic on single-rank-per-DP-slot runs without a separate code path. + """ + if _is_noop(device_mesh, axes): + return True + present = _present_axes(device_mesh, axes) + return all(device_mesh[ax].get_local_rank() == 0 for ax in present) + + +def broadcast_batch( + batch: Any, + device_mesh, + axes: tuple[str, ...] = ("cp", "tp"), +) -> Any: + """Broadcast ``batch`` from the DP source rank to all ranks in the + sub-mesh covering ``axes``. Returns the source's batch on every rank. + + Low-level primitive used internally by :class:`BroadcastingDataLoader`. + Most callers should use the class wrapper rather than calling this + directly. + + No-op (returns ``batch`` unchanged) when ``device_mesh`` is None, every + present named axis has size 1, or distributed isn't initialized. 
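+
+    Minimal calling-convention sketch (illustrative only; ``source_iter`` stands
+    for whatever iterator exists on the source rank, and every rank must make
+    the call so the collective completes)::
+
+        batch = next(source_iter) if is_dp_source_rank(mesh) else None
+        batch = broadcast_batch(batch, mesh)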
+ """ + if _is_noop(device_mesh, axes): + return batch + if not (dist.is_available() and dist.is_initialized()): + return batch + resolved = _resolve_group_and_source(device_mesh, axes) + if resolved is None: + return batch + group, src = resolved + obj_list = [batch] + dist.broadcast_object_list(obj_list, src=src, group=group, device=_broadcast_device(group)) + return obj_list[0] + + +class BroadcastingDataLoader: + """Thin wrapper around (real DataLoader | None) that broadcasts each + batch from the DP source rank to non-source ranks in the ``(cp, tp)`` + sub-mesh. + + Pass ``source=real_loader`` on the DP source rank (``cp_rank == 0`` and + ``tp_rank == 0``); pass ``source=None`` on every other rank. Iteration + issues two broadcasts per step on every rank: a continue/stop boolean + followed by the batch. After the source loader is exhausted, the + continue broadcast is False and iteration ends in lockstep on all + ranks regardless of whether the source exposes ``__len__``. + + ``state_dict`` / ``load_state_dict`` are delegated to the source on the + source rank (no-ops on non-source ranks), so checkpoint/resume keeps + working transparently with ``torch.utils.data.DataLoader``, + ``torchdata.StatefulDataLoader``, or any other source that implements + those methods. + + No-op when ``device_mesh`` is None or every named axis present has + size 1 — iteration delegates to the source loader unchanged. + """ + + def __init__( + self, + source: Iterable | None, + device_mesh, + axes: tuple[str, ...] = ("cp", "tp"), + ): + self._source = source + self._mesh = device_mesh + self._axes = axes + if not _is_noop(device_mesh, axes): + self._is_source = is_dp_source_rank(device_mesh, axes) + if self._is_source and source is None: + raise ValueError("BroadcastingDataLoader on a DP source rank requires a non-None source") + + def __iter__(self) -> Iterator[Any]: + if _is_noop(self._mesh, self._axes): + if self._source is None: + return + yield from self._source + return + if self._is_source: + for batch in self._source: + broadcast_batch(True, self._mesh, self._axes) + broadcast_batch(batch, self._mesh, self._axes) + yield batch + broadcast_batch(False, self._mesh, self._axes) + else: + while True: + keep_iterating = broadcast_batch(None, self._mesh, self._axes) + if not keep_iterating: + return + batch = broadcast_batch(None, self._mesh, self._axes) + yield batch + + def __len__(self) -> int: + # Pass-through when the source defines __len__; raise TypeError + # otherwise (matching Lhotse's typical behavior, which Lightning + # already handles by treating the loader as iterable-style). + if self._source is not None: + return len(self._source) + raise TypeError("BroadcastingDataLoader on non-source rank has no defined length") + + def state_dict(self) -> dict: + if self._source is not None and hasattr(self._source, "state_dict"): + return self._source.state_dict() + return {} + + def load_state_dict(self, state_dict) -> None: + if self._source is not None and hasattr(self._source, "load_state_dict"): + self._source.load_state_dict(state_dict) + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +# Cache: (id(device_mesh), tuple_of_axes) -> (process_group, source_global_rank). +# Sub-mesh creation calls ``_flatten`` which materializes a process group; +# we don't want to repeat that per training step. 
+_GROUP_CACHE: dict[tuple[int, tuple[str, ...]], tuple[Any, int]] = {} + + +def _present_axes(device_mesh, axes: Sequence[str]) -> tuple[str, ...]: + if device_mesh is None: + return () + names = device_mesh.mesh_dim_names or () + return tuple(ax for ax in axes if ax in names) + + +def _is_noop(device_mesh, axes: Sequence[str]) -> bool: + if device_mesh is None: + return True + present = _present_axes(device_mesh, axes) + if not present: + return True + return all(device_mesh[ax].size() == 1 for ax in present) + + +def _broadcast_device(group) -> torch.device: + backend = dist.get_backend(group) + if backend == "nccl" and torch.cuda.is_available(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + return torch.device("cpu") + + +def _resolve_group_and_source(device_mesh, axes: Sequence[str]): + if _is_noop(device_mesh, axes): + return None + present = _present_axes(device_mesh, axes) + cache_key = (id(device_mesh), present) + cached = _GROUP_CACHE.get(cache_key) + if cached is not None: + return cached + + if len(present) == 1: + sub = device_mesh[present[0]] + else: + sub = device_mesh[present]._flatten(mesh_dim_name="_".join(present)) + + group = sub.get_group() + source_global_rank = int(sub.mesh.flatten()[0].item()) + _GROUP_CACHE[cache_key] = (group, source_global_rank) + return group, source_global_rank diff --git a/nemo/collections/speechlm2/data/datamodule.py b/nemo/collections/speechlm2/data/datamodule.py index 0e95542e4ede..df8945dfa865 100644 --- a/nemo/collections/speechlm2/data/datamodule.py +++ b/nemo/collections/speechlm2/data/datamodule.py @@ -18,6 +18,7 @@ from nemo.collections.common.data.fallback import FallbackDataset from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.broadcasting import BroadcastingDataLoader, is_dp_source_rank from nemo.collections.common.tokenizers import TokenizerSpec @@ -72,13 +73,18 @@ def __init__(self, cfg, tokenizer: TokenizerSpec, dataset: torch.utils.data.Data def train_dataloader(self): if "train_ds" not in self.cfg: return None - return get_lhotse_dataloader_from_config( - config=self.cfg.train_ds, - global_rank=self._get_dp_rank(), - world_size=self._get_world_size(), - dataset=FallbackDataset(self.dataset), - tokenizer=self.tokenizer, - ) + mesh = self._get_device_mesh() + if is_dp_source_rank(mesh): + source = get_lhotse_dataloader_from_config( + config=self.cfg.train_ds, + global_rank=self._get_dp_rank(), + world_size=self._get_world_size(), + dataset=FallbackDataset(self.dataset), + tokenizer=self.tokenizer, + ) + else: + source = None + return BroadcastingDataLoader(source=source, device_mesh=mesh) def val_dataloader(self): if "validation_ds" not in self.cfg: @@ -117,13 +123,18 @@ def _build_test_dataloader(self, cfg: DictConfig) -> torch.utils.data.DataLoader with open_dict(cfg): cfg.force_finite = True cfg.force_map_dataset = True - return get_lhotse_dataloader_from_config( - config=cfg, - global_rank=self._get_dp_rank(), - world_size=self._get_world_size(), - dataset=self.dataset, - tokenizer=self.tokenizer, - ) + mesh = self._get_device_mesh() + if is_dp_source_rank(mesh): + source = get_lhotse_dataloader_from_config( + config=cfg, + global_rank=self._get_dp_rank(), + world_size=self._get_world_size(), + dataset=self.dataset, + tokenizer=self.tokenizer, + ) + else: + source = None + return BroadcastingDataLoader(source=source, device_mesh=mesh) # Multiple validation/test dataloaders. 
# Config looks like: @@ -145,6 +156,13 @@ def _build_test_dataloader(self, cfg: DictConfig) -> torch.utils.data.DataLoader dloaders[name] = self._build_test_dataloader(item) return CombinedLoader(dloaders, mode="max_size") + def _get_device_mesh(self): + if not (torch.distributed.is_available() and torch.distributed.is_initialized()): + return None + if hasattr(self.trainer, "model") and hasattr(self.trainer.model, "device_mesh"): + return self.trainer.model.device_mesh + return None + def _get_dp_rank(self): if torch.distributed.is_available() and torch.distributed.is_initialized(): if ( diff --git a/nemo/collections/speechlm2/models/salm_automodel.py b/nemo/collections/speechlm2/models/salm_automodel.py index f759ac01bcc7..0f139461076e 100644 --- a/nemo/collections/speechlm2/models/salm_automodel.py +++ b/nemo/collections/speechlm2/models/salm_automodel.py @@ -149,6 +149,7 @@ def forward( input_embeds: Tensor, attention_mask: Tensor = None, cache=None, + **llm_kwargs, ) -> dict[str, Tensor]: """ Implements a fully offline forward pass through the entire model. @@ -156,14 +157,21 @@ def forward( |speech and text embeddings| -> |llm| -> |lm_head| -> |token ids| + ``llm_kwargs`` carries optional THD/packed-sequence metadata + (``qkv_format``, ``cu_seqlens``, ``position_ids``, ``max_seqlen``); + it is empty for the BSHD path. """ - # input_embeds and out: (B, T, H) + # input_embeds: (B, T, H) for BSHD or (T_total, H) for THD packed + # (the THD shape mirrors Automodel's _shard_thd_chunk_for_te output — + # the model squeezes 3D inputs internally when qkv_format=="thd", so + # passing 2D directly skips that hop) out = self.llm( inputs_embeds=input_embeds, attention_mask=attention_mask, past_key_values=cache, use_cache=cache is not None, return_dict=True, + **llm_kwargs, ) if not isinstance(out, dict): # NeMo Automodel doesn't respect return_dict=True yet @@ -186,39 +194,58 @@ def prepare_inputs(self, batch: dict): * Take care of any necessary slicing to align the shapes of source audio, target audio, and target token ids. """ - # Source audio encoding. - # Input audio: (B, T_samples) - # Audio embeddings: (B, T, H) - audio_embs = encode_audio_with_optional_chunking( + from nemo.collections.speechlm2.parts.cp_helpers import encode_audio_with_cp_distribution, get_cp_mesh + + cp_mesh, _, _ = get_cp_mesh(getattr(self, "_device_mesh", None)) + + # Source audio encoding (distributed across CP ranks when CP is active). + # Input audio: (B_aud, T_samples) → list of (L_i, H) embeddings. + audio_embs = encode_audio_with_cp_distribution( self.perception, batch["audios"], batch["audio_lens"], chunk_size_seconds=self.cfg.get("encoder_chunk_size_seconds", None), sampling_rate=self.sampling_rate, + cp_mesh=cp_mesh, ) input_ids_to_embed = torch.where(batch["input_ids"] == self.audio_locator_tag_id, 0, batch["input_ids"]) text_embs = self._embed_tokens(input_ids_to_embed) + target_ids_full = batch["input_ids"].where(batch["loss_mask"], -100) # CrossEntropyLoss().ignore_index + + # Packed-sequence (THD) path — used for both training and validation when enabled. + # Generate stays on the BSHD path (it doesn't go through prepare_inputs). 
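+        # The THD branch returns ``attention_mask=None``: segment boundaries are
+        # carried by ``cu_seqlens`` inside ``llm_kwargs`` (splatted into the LLM by
+        # ``forward``) rather than by a padding mask.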
+ if self.cfg.get("packed_sequences", False): + from nemo.collections.speechlm2.parts.packed_sequences import prepare_packed_llm_inputs + + return prepare_packed_llm_inputs( + input_ids=batch["input_ids"], + text_embs=text_embs, + audio_embs=audio_embs, + target_ids=target_ids_full, + padding_id=self.text_pad_id, + placeholder_id=self.audio_locator_tag_id, + device_mesh=getattr(self, "_device_mesh", None), + ) + input_embs, target_ids, attention_mask = replace_placeholders_and_build_targets( input_ids=batch["input_ids"], embeds=text_embs, padding_id=self.text_pad_id, placeholder_id=self.audio_locator_tag_id, replacements=audio_embs, - target_ids=batch["input_ids"].where(batch["loss_mask"], -100), # CrossEntropyLoss().ignore_index + target_ids=target_ids_full, ) input_embs = input_embs[:, :-1] attention_mask = attention_mask[:, :-1] target_ids = target_ids[:, 1:] - # Combine target audio and text into a single tensor to slice them together. - # It will also help us truncate the sequence lengths to be divisible by TP world size, - # when TP is enabled. - # Input ids: (B, T, K+1) + # BSHD path runs only when CP is inactive (the fit-start validator + # rejects BSHD + CP > 1, see _validate_parallelism_compatibility). + # Truncate the seq dim to be divisible by tp_size so sequence + # parallelism doesn't reshape the input under us. if self._use_tp: - tp_world_size = self.device_mesh["tp"].size() - if (remainder := (input_embs.shape[1] - 1) % tp_world_size) != 0: - # Truncate some tokens from the end to make the sequence length shape divisible by tensor parallelism - # world size. Otherwise, sequence parallelism will change the input shape making leading to mismatches. + tp_size = self.device_mesh["tp"].size() + if (remainder := (input_embs.shape[1] - 1) % tp_size) != 0: input_embs = input_embs[:, :-remainder] attention_mask = attention_mask[:, :-remainder] target_ids = target_ids[:, :-remainder] @@ -227,13 +254,44 @@ def prepare_inputs(self, batch: dict): "input_embeds": input_embs, "attention_mask": attention_mask, "target_ids": target_ids, + "llm_kwargs": {}, } def on_fit_start(self) -> None: """Configure the MoE aux-loss backward scaler to cancel FSDP's gradient averaging (see ``_configure_moe_aux_loss_scaler``).""" + self._validate_parallelism_compatibility() self._configure_moe_aux_loss_scaler() + def _validate_parallelism_compatibility(self) -> None: + """Raise on known-incompatible THD/CP/backend configurations. + + Delegates to :func:`nemo.collections.speechlm2.parts.parallel.validate_parallelism_compatibility` + with the runtime-derived values from this model's config and device mesh. 
+ """ + import os + + from nemo.collections.speechlm2.parts.parallel import validate_parallelism_compatibility + + cp_size = 1 + device_mesh = getattr(self, "_device_mesh", None) + if device_mesh is not None: + names = device_mesh.mesh_dim_names or () + if "cp" in names: + cp_size = device_mesh["cp"].size() + + attn_backend = self.cfg.get("automodel_backend", {}).get("attn", "te") + nvte_fused_attn = os.environ.get("NVTE_FUSED_ATTN") + device_capability = torch.cuda.get_device_capability() if torch.cuda.is_available() else None + + validate_parallelism_compatibility( + packed_sequences=bool(self.cfg.get("packed_sequences", False)), + cp_size=cp_size, + attn_backend=attn_backend, + nvte_fused_attn=nvte_fused_attn, + device_capability=device_capability, + ) + def training_step(self, batch: dict, batch_idx: int): self._current_batch_idx = batch_idx for m in (self.perception.preprocessor, self.perception.encoder, self.llm): @@ -241,7 +299,11 @@ def training_step(self, batch: dict, batch_idx: int): m.eval() inputs = self.prepare_inputs(batch) - forward_outputs = self(inputs["input_embeds"], attention_mask=inputs["attention_mask"]) + forward_outputs = self( + inputs["input_embeds"], + attention_mask=inputs["attention_mask"], + **inputs.get("llm_kwargs", {}), + ) num_frames = (inputs["target_ids"] != -100).long().sum() # Match Automodel's training recipe: normalize CE by the *global* token count across @@ -260,9 +322,10 @@ def training_step(self, batch: dict, batch_idx: int): num_frames_global = num_frames_global.clamp(min=1) with loss_parallel(): + logits = forward_outputs["logits"] loss_sum = torch.nn.functional.cross_entropy( - forward_outputs["logits"].flatten(0, 1), # (B, T, Vt) -> (*, Vt) - inputs["target_ids"].flatten(0, 1), + logits.reshape(-1, logits.size(-1)), # BSHD (B,T,V) or THD (1,T,V) -> (*, V) + inputs["target_ids"].reshape(-1), # BSHD (B,T) or THD (T,) -> (*,) reduction="sum", ignore_index=-100, ) @@ -273,7 +336,12 @@ def training_step(self, batch: dict, batch_idx: int): with torch.no_grad(): loss_display = loss_sum.detach() / num_frames.clamp(min=1) - B, T = inputs["input_embeds"].shape[:2] + # Input embeds shape is (B, T, H) for BSHD or (T, H) for THD packed. + input_embeds = inputs["input_embeds"] + if input_embeds.dim() == 2: + B, T = 1, input_embeds.shape[0] + else: + B, T = input_embeds.shape[:2] ans = { "loss": loss, "learning_rate": ( @@ -318,13 +386,18 @@ def validation_step(self, batch: dict, batch_idx: int): if dataset_batch is None: continue # some dataset is exhausted inputs = self.prepare_inputs(dataset_batch) - forward_outputs = self(inputs["input_embeds"], attention_mask=inputs["attention_mask"]) + forward_outputs = self( + inputs["input_embeds"], + attention_mask=inputs["attention_mask"], + **inputs.get("llm_kwargs", {}), + ) num_frames = (inputs["target_ids"] != -100).long().sum() with loss_parallel(): + logits = forward_outputs["logits"] loss = ( torch.nn.functional.cross_entropy( - forward_outputs["logits"].flatten(0, 1), - inputs["target_ids"].flatten(0, 1), + logits.reshape(-1, logits.size(-1)), + inputs["target_ids"].reshape(-1), reduction="sum", ignore_index=-100, ) diff --git a/nemo/collections/speechlm2/parts/cp_helpers.py b/nemo/collections/speechlm2/parts/cp_helpers.py new file mode 100644 index 000000000000..f30ffe32f1a9 --- /dev/null +++ b/nemo/collections/speechlm2/parts/cp_helpers.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Context-Parallelism (CP) helpers for SALMAutomodel. + +These helpers consolidate the CP-shape work needed to feed THD packed +batches into a Nemotron-V3 LLM whose attention/Mamba layers were CP-wired +by the Automodel parallelizer (``set_context_parallel_group()`` / +``mixer.cp = MambaContextParallel(...)``). Two concerns: + +1. ``get_cp_mesh`` — read the CP submesh out of a device mesh, returning + ``(None, 1, 0)`` when CP is inactive so callers can short-circuit. +2. ``encode_audio_with_cp_distribution`` — distribute the audio encoder + forward across CP ranks so it isn't recomputed cp_size times. Pads to a + multiple of cp_size with dummy zero-audios so every rank participates in + FSDP all-gather; dummies are dropped after the post-encoder all-gather. +""" +from __future__ import annotations + +from typing import Optional + +import torch +import torch.distributed as dist +from torch import Tensor + +from nemo.collections.speechlm2.parts.encoder_chunking import encode_audio_with_optional_chunking + + +def get_cp_mesh(device_mesh) -> tuple[Optional[object], int, int]: + """Return ``(cp_mesh, cp_size, cp_rank)`` or ``(None, 1, 0)`` when CP is inactive.""" + if device_mesh is None: + return None, 1, 0 + names = device_mesh.mesh_dim_names or () + if "cp" not in names or device_mesh["cp"].size() <= 1: + return None, 1, 0 + cp_mesh = device_mesh["cp"] + cp_rank = dist.get_rank(group=cp_mesh.get_group()) + return cp_mesh, cp_mesh.size(), cp_rank + + +def encode_audio_with_cp_distribution( + perception, + audios: Tensor, + audio_lens: Tensor, + *, + chunk_size_seconds: Optional[float], + sampling_rate: int, + cp_mesh=None, +) -> list[Tensor]: + """Distribute the audio encoder forward across CP ranks. + + Falls back to :func:`encode_audio_with_optional_chunking` when ``cp_mesh is + None`` or there are no audios in the batch. + + With CP active, each rank encodes a contiguous slice of the audio batch + (rank ``r`` gets ``audios[r*per_rank : (r+1)*per_rank]`` where + ``per_rank = ceil(B_aud / cp_size)``). When ``B_aud`` is not a multiple of + ``cp_size`` the audio batch is right-padded with zero-audio dummies; every + rank still calls ``perception`` so FSDP all-gather and activation + checkpointing fire uniformly. The dummy length is set to the smallest real + audio length in the batch (guaranteed to satisfy the encoder's minimum- + length constraints since at least one real sample of that length already + does). + + After local encoding, each rank's variable-length embedding tensors are + zero-padded to a globally-consistent ``max_L`` and ``all_gather``ed across + the CP group. The full ordered list is reconstructed and dummies are + dropped, so the return value is identical on every CP rank. 
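+
+    Worked example (illustrative): with ``B_aud=5`` and ``cp_size=2``,
+    ``per_rank=3``, so one zero-audio dummy is appended; rank 0 encodes audios
+    ``0..2``, rank 1 encodes audios ``3..4`` plus the dummy, and the post-gather
+    reconstruction drops the dummy to return exactly 5 embeddings on every rank.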
+ """ + B_aud = int(audios.shape[0]) + if cp_mesh is None or B_aud == 0: + return encode_audio_with_optional_chunking( + perception, + audios, + audio_lens, + chunk_size_seconds=chunk_size_seconds, + sampling_rate=sampling_rate, + ) + + cp_size = cp_mesh.size() + cp_rank = dist.get_rank(group=cp_mesh.get_group()) + device = audios.device + + per_rank = (B_aud + cp_size - 1) // cp_size + B_padded = per_rank * cp_size + pad_n = B_padded - B_aud + + if pad_n > 0: + dummy_len = int(audio_lens.min().item()) + T_samp = audios.shape[1] + dummy_audios = torch.zeros(pad_n, T_samp, dtype=audios.dtype, device=device) + dummy_lens = torch.full((pad_n,), dummy_len, dtype=audio_lens.dtype, device=device) + audios = torch.cat([audios, dummy_audios], dim=0) + audio_lens = torch.cat([audio_lens, dummy_lens], dim=0) + + start = cp_rank * per_rank + end = start + per_rank + local_audios = audios[start:end] + local_audio_lens = audio_lens[start:end] + + local_embs = encode_audio_with_optional_chunking( + perception, + local_audios, + local_audio_lens, + chunk_size_seconds=chunk_size_seconds, + sampling_rate=sampling_rate, + ) + + # All-gather across CP. Variable-length: pad to a common max-L first. + H = local_embs[0].shape[-1] + local_max_L = max(e.shape[0] for e in local_embs) + max_L_t = torch.tensor(local_max_L, dtype=torch.long, device=device) + dist.all_reduce(max_L_t, op=dist.ReduceOp.MAX, group=cp_mesh.get_group()) + max_L = int(max_L_t.item()) + + local_stack = torch.zeros(per_rank, max_L, H, device=device, dtype=local_embs[0].dtype) + local_lens = torch.zeros(per_rank, dtype=torch.long, device=device) + for i, e in enumerate(local_embs): + local_stack[i, : e.shape[0]] = e + local_lens[i] = e.shape[0] + + gathered_stack = [torch.zeros_like(local_stack) for _ in range(cp_size)] + gathered_lens = [torch.zeros_like(local_lens) for _ in range(cp_size)] + dist.all_gather(gathered_stack, local_stack, group=cp_mesh.get_group()) + dist.all_gather(gathered_lens, local_lens, group=cp_mesh.get_group()) + + full_embs: list[Tensor] = [] + for r in range(cp_size): + for i in range(per_rank): + full_idx = r * per_rank + i + if full_idx >= B_aud: + break # dummy slot + L = int(gathered_lens[r][i].item()) + full_embs.append(gathered_stack[r][i, :L]) + + return full_embs diff --git a/nemo/collections/speechlm2/parts/packed_sequences.py b/nemo/collections/speechlm2/parts/packed_sequences.py new file mode 100644 index 000000000000..074de563c728 --- /dev/null +++ b/nemo/collections/speechlm2/parts/packed_sequences.py @@ -0,0 +1,297 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Packed-sequence (THD-format) helpers for SALMAutomodel training. + +The Nemotron-V3 LLM in `nemo_automodel` already accepts THD batches — TE +attention switches to varlen FlashAttention via `cu_seqlens`, and the Mamba +mixer derives `seq_idx` from the same `cu_seqlens` so SSM state resets at +document boundaries. 
The functions in this module concatenate a SALM +multi-utterance minibatch into a single packed sequence so the LLM is fed +`inputs_embeds` of shape ``[1, T_total, H]`` plus the THD metadata it needs. + +All tensor logic is kept here (no `SALMAutomodel` knowledge) so it is unit +testable on CPU. +""" +from __future__ import annotations + +from typing import Any, Optional + +import torch +from torch import Tensor + +from nemo.collections.speechlm2.models.salm import _unpad_inputs + + +def pack_audio_into_text_embeds( + input_ids: Tensor, + embeds: Tensor, + target_ids: Tensor, + replacements: list[Tensor], + padding_id: int, + placeholder_id: int, + cp_size: int = 1, + tp_size: int = 1, + ignore_index: int = -100, +) -> dict[str, Tensor]: + """Splice audio frames into per-utterance text embeddings and pack into THD. + + Mirrors :func:`replace_placeholders_and_build_targets` but emits a single + flat THD batch instead of a right-padded BSHD one. Labels are next-token + shifted *per utterance* before cross-utterance concatenation, so the LLM + can be called without any further shift. + + Args: + input_ids: ``[B, S]`` int64; left-padded. + embeds: ``[B, S, H]`` text-token embeddings (placeholder slots + are pre-zeroed by the caller; they get overwritten). + target_ids: ``[B, S]`` int64; ``-100`` outside assistant spans. + replacements: list of ``[L_i, H]`` audio-frame embeddings, one per + placeholder occurrence in row-major order. + padding_id: pad-token id in ``input_ids`` (used to strip left-pad + and to mark padding positions as ``ignore_index`` in + labels). + placeholder_id: the ``<|audio|>`` token id. + cp_size: per-utterance flat lengths are rounded up to a + multiple of ``2 * cp_size`` (TE-CP requirement). + tp_size: the last utterance's padded length is bumped so that + ``T_total % tp_size == 0`` (sequence-parallel). + ignore_index: label fill for audio-frame slots, padding slots, and + the last position of every utterance. + + Returns a dict with: + + - ``inputs_embeds`` ``[T_total, H]`` (2D, mirrors Automodel's + ``process_input_for_thd`` / ``_shard_thd_chunk_for_te`` + contract — no leading batch dim) + - ``labels`` ``[T_total]`` int64, already shifted + - ``position_ids`` ``[T_total]`` int64, resets to 0 per utt + - ``seq_lens`` ``[B, 1]`` int64, real per-utt flat lengths + - ``seq_lens_padded`` ``[B, 1]`` int64, post-rounding lengths + - ``cu_seqlens`` ``[B+1]`` int32, ``cumsum`` of ``seq_lens_padded`` + - ``max_seqlen`` int32 scalar, ``max(seq_lens_padded)`` + - ``qkv_format`` ``"thd"`` + """ + B = input_ids.shape[0] + H = embeds.shape[-1] + device = embeds.device + dtype = embeds.dtype + + # Strip left-padding so per-utt sequences are tight before splicing. 
+ ids_unpad, embs_unpad, tgts_unpad = _unpad_inputs(input_ids, embeds, target_ids, padding_id) + + seq_embs: list[Tensor] = [] + seq_labs: list[Tensor] = [] + real_lens: list[int] = [] + rep_idx = 0 + + for i in range(B): + ids_i = ids_unpad[i] + emb_i = embs_unpad[i] + tgt_i = tgts_unpad[i] + placeholders = (ids_i == placeholder_id).nonzero(as_tuple=True)[0].tolist() + + emb_segments: list[Tensor] = [] + lab_segments: list[Tensor] = [] + prev = 0 + for p in placeholders: + if p > prev: + emb_segments.append(emb_i[prev:p]) + seg_lab = tgt_i[prev:p].clone() + seg_lab[ids_i[prev:p] == padding_id] = ignore_index + lab_segments.append(seg_lab) + rep = replacements[rep_idx] + rep_idx += 1 + emb_segments.append(rep) + lab_segments.append(torch.full((rep.shape[0],), ignore_index, dtype=torch.long, device=device)) + prev = p + 1 + if prev < ids_i.numel(): + emb_segments.append(emb_i[prev:]) + seg_lab = tgt_i[prev:].clone() + seg_lab[ids_i[prev:] == padding_id] = ignore_index + lab_segments.append(seg_lab) + + emb_cat = torch.cat(emb_segments, dim=0) # [L_i, H] + lab_cat = torch.cat(lab_segments, dim=0) # [L_i] + L = emb_cat.shape[0] + # Per-utterance next-token shift: labels[t] = orig[t+1], last slot is ignored. + lab_shift = torch.cat( + [lab_cat[1:], torch.full((1,), ignore_index, dtype=torch.long, device=device)], + dim=0, + ) + seq_embs.append(emb_cat) + seq_labs.append(lab_shift) + real_lens.append(L) + + if rep_idx != len(replacements): + raise ValueError( + f"Used {rep_idx} of {len(replacements)} audio replacements — " + f"placeholder occurrences in input_ids do not match replacements length." + ) + + # Round each utterance's length up to a multiple of 2*cp_size (TE-CP + # interleaves 2 chunks per rank); skip rounding when cp_size == 1. Then + # bump the last so the total is divisible by tp_size for sequence + # parallelism. + if cp_size > 1: + cp_mult = 2 * cp_size + padded_lens = [((L + cp_mult - 1) // cp_mult) * cp_mult for L in real_lens] + else: + padded_lens = list(real_lens) + if tp_size > 1: + rem = sum(padded_lens) % tp_size + if rem != 0: + padded_lens[-1] += tp_size - rem + + # Materialize the flat THD batch. 
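+    # Layout per utterance: [real tokens | zero-embedding pad up to l_pad]; pad
+    # slots get ``ignore_index`` labels, and position_ids keep counting past
+    # l_real so every flat position carries a distinct index.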
+ flat_emb_segs: list[Tensor] = [] + flat_lab_segs: list[Tensor] = [] + flat_pos_segs: list[Tensor] = [] + for emb, lab, l_real, l_pad in zip(seq_embs, seq_labs, real_lens, padded_lens): + flat_emb_segs.append(emb) + flat_lab_segs.append(lab) + flat_pos_segs.append(torch.arange(l_real, dtype=torch.long, device=device)) + pad_n = l_pad - l_real + if pad_n > 0: + flat_emb_segs.append(torch.zeros(pad_n, H, dtype=dtype, device=device)) + flat_lab_segs.append(torch.full((pad_n,), ignore_index, dtype=torch.long, device=device)) + flat_pos_segs.append(torch.arange(l_real, l_pad, dtype=torch.long, device=device)) + + inputs_embeds = torch.cat(flat_emb_segs, dim=0) # [T_total, H] + labels = torch.cat(flat_lab_segs, dim=0) # [T_total] + position_ids = torch.cat(flat_pos_segs, dim=0) # [T_total] + + seq_lens = torch.tensor(real_lens, dtype=torch.long, device=device).unsqueeze(-1) + seq_lens_padded = torch.tensor(padded_lens, dtype=torch.long, device=device).unsqueeze(-1) + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.tensor(padded_lens, dtype=torch.int32, device=device).cumsum(0).to(torch.int32), + ] + ) + max_seqlen = torch.tensor(max(padded_lens), dtype=torch.int32, device=device) + + return { + "inputs_embeds": inputs_embeds, + "labels": labels, + "position_ids": position_ids, + "seq_lens": seq_lens, + "seq_lens_padded": seq_lens_padded, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "qkv_format": "thd", + } + + +def _shard_packed_for_cp(packed: dict[str, Tensor], cp_mesh) -> dict[str, Tensor]: + """Partition a packed THD batch across CP ranks (TE's interleaved scheme). + + Mirrors ``nemo_automodel.components.distributed.cp_utils._shard_thd_chunk_for_te`` + but preserves the float dtype of ``inputs_embeds`` (the upstream helper + casts everything to int64, which would silently corrupt embeddings). + """ + import transformer_engine_torch as tex # local import — only needed when CP > 1 + + cp_size = cp_mesh.size() + cp_rank = torch.distributed.get_rank(group=cp_mesh.get_group()) + + cu_seqlens = packed["cu_seqlens"] + inputs_embeds = packed["inputs_embeds"] # [T, H] + labels = packed["labels"] # [T] + position_ids = packed["position_ids"] # [T] + + index = tex.thd_get_partitioned_indices(cu_seqlens, inputs_embeds.shape[0], cp_size, cp_rank) + inputs_embeds = inputs_embeds.index_select(0, index) + labels = labels.index_select(0, index) + position_ids = position_ids.index_select(0, index) + + return { + "inputs_embeds": inputs_embeds.contiguous(), + "labels": labels.to(torch.int64).contiguous(), + "position_ids": position_ids.to(torch.int64).contiguous(), + "cu_seqlens": cu_seqlens.to(torch.int32).contiguous(), + "max_seqlen": packed["max_seqlen"], + "qkv_format": "thd", + } + + +def prepare_packed_llm_inputs( + input_ids: Tensor, + text_embs: Tensor, + audio_embs: list[Tensor], + target_ids: Tensor, + padding_id: int, + placeholder_id: int, + device_mesh: Optional[Any] = None, +) -> dict[str, Any]: + """Pack a SALM minibatch and (optionally) shard it across CP ranks. 
+ + Returns a dict with the same top-level keys produced by the BSHD branch of + ``SALMAutomodel.prepare_inputs`` plus an ``llm_kwargs`` dict carrying the + THD metadata to splat into ``self.llm(...)``:: + + { + "input_embeds": Tensor [T, H] (2D, no leading batch dim; + matches the canonical Automodel THD contract + produced by ``_shard_thd_chunk_for_te``), + "attention_mask": None, + "target_ids": Tensor [T], + "llm_kwargs": { + "qkv_format": "thd", + "cu_seqlens": Tensor [B+1] int32, + "position_ids": Tensor [T] int64, + "max_seqlen": int32 scalar, + }, + } + """ + from nemo.collections.speechlm2.parts.cp_helpers import get_cp_mesh + + cp_mesh, cp_size, _ = get_cp_mesh(device_mesh) + tp_size = 1 + if device_mesh is not None: + names = device_mesh.mesh_dim_names or () + if "tp" in names and device_mesh["tp"].size() > 1: + tp_size = device_mesh["tp"].size() + + packed = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=text_embs, + target_ids=target_ids, + replacements=audio_embs, + padding_id=padding_id, + placeholder_id=placeholder_id, + cp_size=cp_size, + tp_size=tp_size, + ) + + if cp_mesh is not None: + packed = _shard_packed_for_cp(packed, cp_mesh) + + return { + "input_embeds": packed["inputs_embeds"], + "attention_mask": None, + "target_ids": packed["labels"], + "llm_kwargs": { + "qkv_format": "thd", + # Match Automodel's standard THD contract (``thd_utils.process_input_for_thd`` + # and ``cp_utils._shard_thd_chunk_for_te``): emit only ``cu_seqlens`` (the + # padded cumsum) and a single ``max_seqlen``. Passing ``cu_seqlens_padded`` + # too would activate the ``pad_between_seqs=True`` branch in + # ``Automodel/.../attention/utils.py``, which routes TE down a different + # attention path. Passing pre-split ``max_seqlen_q`` / ``max_seqlen_kv`` + # gets them silently dropped by the preprocessor. + "cu_seqlens": packed["cu_seqlens"], + "position_ids": packed["position_ids"], + "max_seqlen": packed["max_seqlen"], + }, + } diff --git a/nemo/collections/speechlm2/parts/parallel.py b/nemo/collections/speechlm2/parts/parallel.py index 004e43e920dc..608e67b8528a 100644 --- a/nemo/collections/speechlm2/parts/parallel.py +++ b/nemo/collections/speechlm2/parts/parallel.py @@ -15,6 +15,7 @@ from __future__ import annotations import os +import warnings from datetime import timedelta from typing import Any, Dict, Optional @@ -25,6 +26,94 @@ from typing_extensions import override +# Blackwell sm_120, where TE 2.14's cuDNN fused-attention backward kernel +# silently amplifies THD/padding_causal gradients 8x-960x per layer. +_SM120 = (12, 0) + + +def validate_parallelism_compatibility( + *, + packed_sequences: bool, + cp_size: int, + attn_backend: str, + nvte_fused_attn: Optional[str], + device_capability: Optional[tuple[int, int]], +) -> None: + """Raise on known-incompatible SALMAutomodel configurations. + + Catches three combinations that produce silent NaN gradients or + hangs at training time: + + 1. ``packed_sequences=False`` (BSHD) under ``cp_size > 1``: TE's + fused-attention CP path rejects ``padding_causal``, so the + right-pad mask must be dropped. With the mask dropped pad K/V + leak into real-token attention through the causal-only mask and + gradients become NaN after step 1. No supported workaround; + must use the THD path. + 2. ``packed_sequences=True`` (THD) with ``attn != "te"``: the THD + packing emits a 2D ``[T_total, H]`` layout via TE's + ``thd_get_partitioned_indices`` and feeds TE varlen + FlashAttention. 
SDPA's 3D-THD path is broken in the Automodel + branch we depend on (transpose assumes 4D BSHD). + 3. ``packed_sequences=True`` + ``attn="te"`` + + ``NVTE_FUSED_ATTN != "0"``: TE 2.14's cuDNN fused-attention + backward kernel produces forward outputs that match FA bit-for-bit + but a backward that amplifies gradients 8x-960x per layer on + Blackwell sm_120. Compounded across the LLM's attention stack + this drives gradients to ``inf`` and the optimizer to NaN. Set + ``NVTE_FUSED_ATTN=0`` in the launcher environment to force + FlashAttention dispatch. + + Hard error on (1), (2), and (3)-on-sm_120; ``warnings.warn`` on + (3) for other architectures (the bug may not apply but we have no + way to be certain). + + Pure function — no side effects on globals or environment, so it + can be unit-tested with synthetic inputs. Called from + :meth:`SALMAutomodel.on_fit_start` once the device mesh is wired + up. + """ + # Case 1: BSHD + CP > 1 — hard incompatibility. + if not packed_sequences and cp_size > 1: + raise ValueError( + "SALMAutomodel: BSHD (model.packed_sequences=false) is incompatible " + f"with cp_size > 1 (got cp_size={cp_size}). TE's fused-attention CP path " + "rejects ``padding_causal``, so the right-pad mask is dropped before the " + "LLM, which lets pad K/V leak into real-token attention through the " + "causal mask and produces NaN gradients after step 1. " + "Set ``model.packed_sequences: true`` to use the THD path under CP " + "(see docs/source/speechlm2/training_and_scaling.rst)." + ) + + if packed_sequences: + # Case 2: THD path requires TE attention (SDPA THD is broken upstream). + if attn_backend != "te": + raise ValueError( + "SALMAutomodel: THD (model.packed_sequences=true) requires " + "``model.automodel_backend.attn=te``; " + f"got ``attn={attn_backend!r}``. SDPA's THD code path in the " + "Automodel branch transposes assuming 4D BSHD inputs and breaks " + "for the 2D [T_total, H] THD layout." + ) + + # Case 3: THD + TE attention without NVTE_FUSED_ATTN=0. + if nvte_fused_attn != "0": + msg = ( + "SALMAutomodel: ``packed_sequences=true`` with ``attn=te`` and " + "``NVTE_FUSED_ATTN`` not set to ``\"0\"`` (got " + f"{nvte_fused_attn!r}). TE 2.14's cuDNN fused-attention " + "backward kernel amplifies THD/padding_causal gradients " + "8x-960x per layer on Blackwell sm_120; the resulting ``inf`` " + "gradients drive the optimizer to NaN. Set " + "``NVTE_FUSED_ATTN=0`` in the launcher environment to force " + "FlashAttention dispatch (requires ``flash-attn`` installed " + "for your GPU arch)." + ) + if device_capability == _SM120: + raise ValueError(msg) + warnings.warn(msg, stacklevel=2) + + def setup_distributed( tp_size: int = 1, pp_size: int = 1, diff --git a/tests/collections/common/data/lhotse/test_broadcasting.py b/tests/collections/common/data/lhotse/test_broadcasting.py new file mode 100644 index 000000000000..45bc589ed10e --- /dev/null +++ b/tests/collections/common/data/lhotse/test_broadcasting.py @@ -0,0 +1,255 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for nemo/collections/common/data/lhotse/broadcasting.py. + +Fake-mesh tests run on CPU without a real distributed group — they +exercise the noop short-circuits and the rank-coordinate logic. The +gloo-based multiprocess tests verify the broadcast contract end-to-end +on a 2-rank CPU group. +""" +from __future__ import annotations + +import os +import socket +from typing import Any + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from nemo.collections.common.data.lhotse.broadcasting import BroadcastingDataLoader, broadcast_batch, is_dp_source_rank + +# --------------------------------------------------------------------------- +# Fake-mesh CPU-only tests (no distributed required). +# --------------------------------------------------------------------------- + + +class _FakeAxis: + def __init__(self, size: int, local_rank: int): + self._size = size + self._local_rank = local_rank + + def size(self) -> int: + return self._size + + def get_local_rank(self) -> int: + return self._local_rank + + +class _FakeMesh: + """Minimal DeviceMesh stand-in covering ``mesh_dim_names`` + ``__getitem__``.""" + + def __init__(self, sizes: dict[str, int], coords: dict[str, int]): + assert sizes.keys() == coords.keys() + self.mesh_dim_names = tuple(sizes.keys()) + self._sizes = sizes + self._coords = coords + + def __getitem__(self, name): + if isinstance(name, tuple): + raise NotImplementedError("multi-axis slicing not needed for fake-mesh tests") + return _FakeAxis(self._sizes[name], self._coords[name]) + + +def test_is_dp_source_rank_none_mesh(): + assert is_dp_source_rank(None) is True + + +def test_is_dp_source_rank_all_axes_size_one(): + mesh = _FakeMesh({"cp": 1, "tp": 1}, {"cp": 0, "tp": 0}) + assert is_dp_source_rank(mesh) is True + + +def test_is_dp_source_rank_no_relevant_axes(): + mesh = _FakeMesh({"dp": 2}, {"dp": 1}) + assert is_dp_source_rank(mesh) is True + + +@pytest.mark.parametrize( + "coords, expected", + [ + ({"cp": 0, "tp": 0}, True), + ({"cp": 1, "tp": 0}, False), + ({"cp": 0, "tp": 1}, False), + ({"cp": 1, "tp": 1}, False), + ], +) +def test_is_dp_source_rank_cp_tp_grid(coords, expected): + mesh = _FakeMesh({"cp": 2, "tp": 2}, coords) + assert is_dp_source_rank(mesh) is expected + + +def test_is_dp_source_rank_only_cp_axis(): + mesh = _FakeMesh({"cp": 4}, {"cp": 3}) + assert is_dp_source_rank(mesh) is False + + +def test_broadcast_batch_noop_returns_input(): + payload = {"x": torch.arange(4)} + out = broadcast_batch(payload, None) + assert out is payload + + +def test_broadcast_batch_noop_when_axes_size_one(): + mesh = _FakeMesh({"cp": 1, "tp": 1}, {"cp": 0, "tp": 0}) + payload = "anything" + assert broadcast_batch(payload, mesh) is payload + + +def test_broadcasting_dataloader_noop_iterates_source(): + real = [{"i": i} for i in range(4)] + loader = BroadcastingDataLoader(source=real, device_mesh=None) + assert list(loader) == real + + +def test_broadcasting_dataloader_noop_with_no_source_is_empty(): + loader = BroadcastingDataLoader(source=None, device_mesh=None) + assert list(loader) == [] + + +def test_broadcasting_dataloader_noop_state_dict_passthrough(): + class _Stateful: + def state_dict(self): + return {"cursor": 5} + + def load_state_dict(self, sd): + self._restored = sd + + def __iter__(self): + return iter([]) + + src = _Stateful() + loader = BroadcastingDataLoader(source=src, device_mesh=None) + assert 
loader.state_dict() == {"cursor": 5} + loader.load_state_dict({"cursor": 10}) + assert src._restored == {"cursor": 10} + + +def test_broadcasting_dataloader_state_dict_empty_when_source_lacks_method(): + loader = BroadcastingDataLoader(source=[1, 2, 3], device_mesh=None) + assert loader.state_dict() == {} + loader.load_state_dict({"anything": 1}) # must not raise + + +def test_broadcasting_dataloader_passes_through_len_when_available(): + loader = BroadcastingDataLoader(source=[1, 2, 3, 4, 5], device_mesh=None) + assert len(loader) == 5 + + +def test_broadcasting_dataloader_len_raises_when_source_has_no_len(): + class _NoLen: + def __iter__(self): + return iter([]) + + loader = BroadcastingDataLoader(source=_NoLen(), device_mesh=None) + with pytest.raises(TypeError): + len(loader) + + +# --------------------------------------------------------------------------- +# Distributed (gloo) end-to-end tests for the broadcast contract. +# --------------------------------------------------------------------------- + + +def _get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _init_gloo(rank: int, world_size: int, port: int) -> None: + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + + +def _build_cp_mesh(world_size: int): + return torch.distributed.device_mesh.init_device_mesh( + device_type="cpu", + mesh_shape=(world_size,), + mesh_dim_names=("cp",), + ) + + +def _broadcast_batch_worker(rank: int, world_size: int, port: int, queue: mp.Queue) -> None: + try: + _init_gloo(rank, world_size, port) + mesh = _build_cp_mesh(world_size) + if is_dp_source_rank(mesh): + payload: Any = {"tensor": torch.arange(8), "name": "hello"} + else: + payload = None + result = broadcast_batch(payload, mesh) + if isinstance(result, dict): + queue.put(("ok", result["tensor"].tolist(), result["name"])) + else: + queue.put(("ok", None, None)) + except Exception as e: + queue.put(("err", repr(e), None)) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _broadcasting_loader_worker(rank: int, world_size: int, port: int, queue: mp.Queue) -> None: + try: + _init_gloo(rank, world_size, port) + mesh = _build_cp_mesh(world_size) + source = [{"i": i} for i in range(3)] if is_dp_source_rank(mesh) else None + loader = BroadcastingDataLoader(source=source, device_mesh=mesh) + received = [batch["i"] for batch in loader] + queue.put(("ok", received)) + except Exception as e: + queue.put(("err", repr(e))) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _spawn_workers(target, world_size: int) -> list: + ctx = mp.get_context("spawn") + queue = ctx.Queue() + port = _get_free_port() + procs = [ctx.Process(target=target, args=(rank, world_size, port, queue)) for rank in range(world_size)] + for p in procs: + p.start() + for p in procs: + p.join(timeout=120) + results = [] + while not queue.empty(): + results.append(queue.get()) + for p in procs: + if p.exitcode != 0 and p.is_alive(): + p.terminate() + return results + + +def test_broadcast_batch_dispatches_payload_across_ranks(): + results = _spawn_workers(_broadcast_batch_worker, world_size=2) + assert len(results) == 2, results + for status, tensor_list, name in results: + assert status == "ok", results + assert tensor_list == list(range(8)) + assert 
name == "hello" + + +def test_broadcasting_dataloader_iterates_in_lockstep_across_ranks(): + results = _spawn_workers(_broadcasting_loader_worker, world_size=2) + assert len(results) == 2, results + for status, received in results: + assert status == "ok", results + assert received == [0, 1, 2] diff --git a/tests/collections/speechlm2/test_datamodule.py b/tests/collections/speechlm2/test_datamodule.py index c5438c6616d5..253ded844aa1 100644 --- a/tests/collections/speechlm2/test_datamodule.py +++ b/tests/collections/speechlm2/test_datamodule.py @@ -18,6 +18,7 @@ from lightning.pytorch.utilities import CombinedLoader from omegaconf import DictConfig +from nemo.collections.common.data.lhotse.broadcasting import BroadcastingDataLoader from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model from nemo.collections.speechlm2.data import DataModule @@ -90,7 +91,7 @@ def __getitem__(self, item): def test_datamodule_train_dataloader(data_config, tokenizer): data = DataModule(data_config, tokenizer=tokenizer, dataset=Identity()) dl = data.train_dataloader() - assert isinstance(dl, torch.utils.data.DataLoader) + assert isinstance(dl, (BroadcastingDataLoader, torch.utils.data.DataLoader)) dli = iter(dl) batch = next(dli) diff --git a/tests/collections/speechlm2/test_salm_cp_helpers.py b/tests/collections/speechlm2/test_salm_cp_helpers.py new file mode 100644 index 000000000000..62f57d47b523 --- /dev/null +++ b/tests/collections/speechlm2/test_salm_cp_helpers.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CPU-only tests for the CP-helper module. + +The ``cp_size > 1`` path in ``encode_audio_with_cp_distribution`` requires +a real ``torch.distributed`` process group; it's exercised by the 2-GPU +smoke. These tests cover the fallback contracts that run on every machine +(``cp_mesh is None``, ``B_aud == 0``). 
+""" +import torch + +from nemo.collections.speechlm2.parts.cp_helpers import encode_audio_with_cp_distribution, get_cp_mesh + + +def test_get_cp_mesh_none(): + assert get_cp_mesh(None) == (None, 1, 0) + + +class _DummyCpDim: + """Stand-in for ``device_mesh['cp']`` whose ``.size()`` is 1 (CP inactive).""" + + def size(self): + return 1 + + +class _DummyDeviceMesh: + """Minimal ``DeviceMesh``-like object exposing only the bits ``get_cp_mesh`` reads.""" + + def __init__(self, cp_size: int = 1, has_cp: bool = True): + self.mesh_dim_names = ("dp", "cp", "tp") if has_cp else ("dp", "tp") + self._cp_size = cp_size + + def __getitem__(self, key): + if key == "cp": + + class _Dim: + def __init__(self, size): + self._size = size + + def size(self): + return self._size + + return _Dim(self._cp_size) + raise KeyError(key) + + +def test_get_cp_mesh_cp_size_one(): + assert get_cp_mesh(_DummyDeviceMesh(cp_size=1)) == (None, 1, 0) + + +def test_get_cp_mesh_no_cp_dim(): + assert get_cp_mesh(_DummyDeviceMesh(has_cp=False)) == (None, 1, 0) + + +class _PerceptionStub: + """Stand-in for ``self.perception``: returns a deterministic embedding per audio.""" + + def __init__(self, hidden_size: int = 4): + self.hidden_size = hidden_size + + def __call__(self, *, input_signal, input_signal_length): + # Pretend each audio of length L produces L // 2 frames of embeddings; + # encode the row index into the first column so we can verify ordering. + B, T = input_signal.shape + if B == 0: + return torch.zeros(0, 0, self.hidden_size, dtype=torch.float32), input_signal_length + # Frame count per row scales with audio_lens. + out_lens = (input_signal_length // 2).clamp(min=1) + max_out = int(out_lens.max().item()) + embs = torch.zeros(B, max_out, self.hidden_size, dtype=torch.float32) + for i in range(B): + embs[i, : int(out_lens[i].item()), 0] = float(i) # marker + return embs, out_lens + + +def test_encode_audio_no_cp_returns_unpadded_list(): + perception = _PerceptionStub(hidden_size=4) + audios = torch.zeros(3, 1600, dtype=torch.float32) + audio_lens = torch.tensor([800, 1200, 1600], dtype=torch.long) + embs = encode_audio_with_cp_distribution( + perception, + audios, + audio_lens, + chunk_size_seconds=None, + sampling_rate=16000, + cp_mesh=None, + ) + # 3 audios → 3 embedding tensors with row-specific lengths. + assert len(embs) == 3 + expected_lens = [400, 600, 800] + for i, e in enumerate(embs): + assert e.shape == (expected_lens[i], 4) + # Marker preserved. + assert torch.all(e[:, 0] == float(i)) + + +def test_encode_audio_empty_batch_returns_empty(): + perception = _PerceptionStub() + audios = torch.zeros(0, 1600, dtype=torch.float32) + audio_lens = torch.zeros(0, dtype=torch.long) + embs = encode_audio_with_cp_distribution( + perception, + audios, + audio_lens, + chunk_size_seconds=None, + sampling_rate=16000, + cp_mesh=None, + ) + assert embs == [] diff --git a/tests/collections/speechlm2/test_salm_packed_sequences.py b/tests/collections/speechlm2/test_salm_packed_sequences.py new file mode 100644 index 000000000000..27ab57b5c3ef --- /dev/null +++ b/tests/collections/speechlm2/test_salm_packed_sequences.py @@ -0,0 +1,508 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from nemo.collections.speechlm2.parts.packed_sequences import pack_audio_into_text_embeds, prepare_packed_llm_inputs + +PAD = 0 +AUDIO = 100 + + +def _basic_batch(): + """Mirrors `test_audio_placeholders.py::test_replace_placeholders`. + + Two utterances, three audio replacements of length 4, 3, 2. + """ + input_ids = torch.tensor( + [ + [7, AUDIO, 1, 2, AUDIO, 1], + [PAD, PAD, 3, AUDIO, 4, 5], # left-padded + ] + ) + loss_mask = torch.tensor( + [ + [False, False, False, False, False, True], + [False, False, False, False, True, True], + ] + ) + embeds = torch.ones(2, 6, 2) + embeds[1, :2] = 0 # zero left-pad slots + replacements = [ + torch.full((4, 2), fill_value=2.0), + torch.full((3, 2), fill_value=3.0), + torch.full((2, 2), fill_value=4.0), + ] + target_ids = input_ids.where(loss_mask, -100) + return input_ids, embeds, target_ids, replacements + + +def test_basic_pack_shapes_and_cu_seqlens(): + input_ids, embeds, target_ids, replacements = _basic_batch() + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + ) + # Real per-utterance flat lengths: + # utt0: 1 text + 4 audio + 2 text + 3 audio + 1 text = 11 + # utt1: 1 text + 2 audio + 2 text = 5 (left-pad of 2 stripped) + assert out["seq_lens"].squeeze(-1).tolist() == [11, 5] + # cp_size=1, tp_size=1 ⇒ no rounding + assert out["seq_lens_padded"].squeeze(-1).tolist() == [11, 5] + # cu_seqlens = [0] + cumsum(seq_lens_padded) + assert out["cu_seqlens"].dtype == torch.int32 + assert out["cu_seqlens"].tolist() == [0, 11, 16] + assert out["max_seqlen"].item() == 11 + assert out["qkv_format"] == "thd" + + T_total = 11 + 5 + assert out["inputs_embeds"].shape == (T_total, 2) + assert out["labels"].shape == (T_total,) + assert out["position_ids"].shape == (T_total,) + + +def test_position_ids_reset_per_utt(): + input_ids, embeds, target_ids, replacements = _basic_batch() + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + ) + pos = out["position_ids"] + cu = out["cu_seqlens"].tolist() + for start, end in zip(cu[:-1], cu[1:]): + assert pos[start].item() == 0 + assert torch.equal(pos[start:end], torch.arange(end - start, dtype=torch.long)) + + +def test_audio_frame_labels_are_ignored(): + """Audio-frame slots must be -100 in `labels` regardless of loss_mask.""" + input_ids, embeds, target_ids, replacements = _basic_batch() + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + ) + labels = out["labels"] + cu = out["cu_seqlens"].tolist() + # Utterance 0 layout: [t, a, a, a, a, t, t, a, a, a, t] + # pos 0 1 2 3 4 5 6 7 8 9 10 + # Audio slots before shift: 1..4 and 7..9. After per-utt next-token shift, + # the *previous* slot's label becomes the audio target → also ignored once + # the original slot's label was -100. 
Verify all original audio slots map + # to -100 in the shifted output (audio at t means lab_shift[t-1] gets + # what was at t, which is -100 from the audio fill). + utt0 = labels[cu[0] : cu[1]] + # Shifted: original audio at positions 1-4 → label[0..3] should be -100; + # audio at 7-9 → label[6..8] should be -100; last slot (10) is -100. + assert (utt0[0:4] == -100).all() + assert (utt0[6:9] == -100).all() + assert utt0[-1].item() == -100 + + +def test_labels_shifted_per_utt(): + """`labels[t]` should equal the original `target_ids` at position t+1 + *within the utterance*, with the last slot set to -100.""" + input_ids, embeds, target_ids, replacements = _basic_batch() + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + ) + labels = out["labels"] + cu = out["cu_seqlens"].tolist() + # Utt 0 last slot is the trailing text "1" with loss_mask=True (target=1), + # which after per-utt shift becomes the label of position 9 (the last + # audio frame). Since the shift puts orig[t+1] at slot t, the second-to- + # last slot of utt0 holds target_ids of the trailing "1". + utt0 = labels[cu[0] : cu[1]] + assert utt0[-2].item() == 1 # trailing text token "1" with loss_mask=True + assert utt0[-1].item() == -100 # last position of every utterance is -100 + + # Utt 1 last two text tokens (4, 5) had loss_mask=True. After per-utt + # shift, label[L-3] = 4, label[L-2] = 5, label[L-1] = -100. + utt1 = labels[cu[1] : cu[2]] + assert utt1[-3].item() == 4 + assert utt1[-2].item() == 5 + assert utt1[-1].item() == -100 + + +def test_no_audio_utterance(): + """Utterance without any audio placeholders still packs correctly.""" + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + loss_mask = torch.tensor([[False, False, True, True, True]]) + embeds = torch.full((1, 5, 2), 1.0) + target_ids = input_ids.where(loss_mask, -100) + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=[], + padding_id=PAD, + placeholder_id=AUDIO, + ) + assert out["seq_lens"].squeeze(-1).tolist() == [5] + assert out["seq_lens_padded"].squeeze(-1).tolist() == [5] + assert out["cu_seqlens"].tolist() == [0, 5] + labels = out["labels"] + # Original target_ids = [-100, -100, 3, 4, 5]; after per-utt shift: + # [-100, 3, 4, 5, -100] + assert labels.tolist() == [-100, 3, 4, 5, -100] + + +def test_b_one(): + """Single-utterance batch produces valid `cu_seqlens=[0, L]`.""" + input_ids = torch.tensor([[1, AUDIO, 2]]) + loss_mask = torch.tensor([[False, False, True]]) + embeds = torch.full((1, 3, 2), 1.0) + target_ids = input_ids.where(loss_mask, -100) + replacements = [torch.full((3, 2), 7.0)] + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + ) + assert out["seq_lens"].squeeze(-1).tolist() == [5] # 1 + 3 audio + 1 + assert out["cu_seqlens"].tolist() == [0, 5] + assert out["inputs_embeds"].shape == (5, 2) + + +@pytest.mark.parametrize("cp_size", [2, 4]) +def test_cp_divisibility(cp_size): + """Each per-utt padded length is a multiple of 2*cp_size.""" + input_ids, embeds, target_ids, replacements = _basic_batch() + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + cp_size=cp_size, + ) + mult = 2 * cp_size + for L in 
out["seq_lens_padded"].squeeze(-1).tolist(): + assert L % mult == 0 + + +def test_tp_divisibility(): + """`T_total` is a multiple of tp_size (last utterance gets the bump).""" + input_ids, embeds, target_ids, replacements = _basic_batch() + # real_lens = [11, 5], total = 16; pick tp_size=3 to force a bump. + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + tp_size=3, + ) + T_total = out["seq_lens_padded"].sum().item() + assert T_total % 3 == 0 + # First utterance untouched (only the last gets the TP bump). + assert out["seq_lens_padded"].squeeze(-1).tolist()[0] == 11 + # Last utterance bumped from 5 → 7 (next multiple of 3 after 16 is 18). + assert out["seq_lens_padded"].squeeze(-1).tolist()[1] == 7 + + +def test_tp_and_cp_combined(): + input_ids, embeds, target_ids, replacements = _basic_batch() + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + cp_size=2, + tp_size=8, + ) + padded = out["seq_lens_padded"].squeeze(-1).tolist() + real = out["seq_lens"].squeeze(-1).tolist() + # Every padded length ≥ real and divisible by 4. + for r, p in zip(real, padded): + assert p >= r + assert p % 4 == 0 + # Total divisible by tp_size. + assert sum(padded) % 8 == 0 + + +def test_cu_seqlens_matches_padded_cumsum(): + input_ids, embeds, target_ids, replacements = _basic_batch() + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=replacements, + padding_id=PAD, + placeholder_id=AUDIO, + cp_size=2, + tp_size=8, + ) + expected = [0] + for L in out["seq_lens_padded"].squeeze(-1).tolist(): + expected.append(expected[-1] + L) + assert out["cu_seqlens"].tolist() == expected + assert out["max_seqlen"].item() == max(out["seq_lens_padded"].squeeze(-1).tolist()) + + +def test_loss_mask_propagates_to_minus_100(): + """Positions where loss_mask=False end up as -100 in the shifted labels.""" + input_ids = torch.tensor([[1, 2, 3, 4]]) + loss_mask = torch.tensor([[False, False, False, False]]) # nothing supervised + embeds = torch.full((1, 4, 2), 1.0) + target_ids = input_ids.where(loss_mask, -100) + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=[], + padding_id=PAD, + placeholder_id=AUDIO, + ) + assert (out["labels"] == -100).all() + + +def test_prepare_packed_llm_inputs_attention_kwargs_reach_te_preprocessor(): + """End-to-end contract: ``prepare_packed_llm_inputs`` → Automodel's TE + attention preprocessor must yield ``max_seqlen_q``/``max_seqlen_kv`` and + ``cu_seqlens_q``/``cu_seqlens_kv`` populated for the THD path. + + Regression for: ``prepare_packed_llm_inputs`` previously emitted + ``max_seqlen_q``/``max_seqlen_kv`` (plural), but the preprocessor only + inspects the singular ``max_seqlen`` key in the ``cu_seqlens`` branch + (``Automodel/.../attention/utils.py``). The plural keys were silently + dropped, TE fell back to a degenerate varlen path, and step-1 backward + produced NaN gradients that poisoned every subsequent step. 
+ """ + automodel_attn = pytest.importorskip( + "nemo_automodel.components.attention.utils", + reason="Automodel attention preprocessor required for the contract test.", + ) + input_ids, embeds, target_ids, replacements = _basic_batch() + out = prepare_packed_llm_inputs( + input_ids=input_ids, + text_embs=embeds, + audio_embs=replacements, + target_ids=target_ids, + padding_id=PAD, + placeholder_id=AUDIO, + device_mesh=None, # CP=1, TP=1 path + ) + llm_kwargs = out["llm_kwargs"] + assert out["attention_mask"] is None + assert llm_kwargs["qkv_format"] == "thd" + assert "cu_seqlens" in llm_kwargs + assert "max_seqlen" in llm_kwargs, ( + "Automodel's preprocessor only checks the singular `max_seqlen` key in the " + "cu_seqlens THD branch; pre-split `max_seqlen_q`/`max_seqlen_kv` would be dropped." + ) + assert "cu_seqlens_padded" not in llm_kwargs, ( + "Standard Automodel pipeline emits only `cu_seqlens`, never both `cu_seqlens` " + "and `cu_seqlens_padded`. Passing both activates the `pad_between_seqs=True` " + "branch in Automodel/.../attention/utils.py, routing TE down a different path." + ) + + # Run the LLM kwargs through the preprocessor exactly as the attention + # layer does: 4D BSHD-shaped Q/K/V plus attention_mask=None plus the + # llm_kwargs splatted in. ``input_embeds`` is now 2D ``[T, H]`` per the + # canonical Automodel THD shape contract. + B, T, H = 1, int(out["input_embeds"].shape[0]), 2 + nh, hd = 2, 4 + q = torch.zeros(B, T, nh, hd) + k = torch.zeros(B, T, nh, hd) + v = torch.zeros(B, T, nh, hd) + _, _, _, te_attn_kwargs = automodel_attn.preprocess_args_and_kwargs_for_attn( + q, k, v, attention_mask=None, attn_impl="te", **llm_kwargs + ) + assert te_attn_kwargs.get("qkv_format") == "thd" + assert te_attn_kwargs.get("attn_mask_type") == "padding_causal" + assert "cu_seqlens_q" in te_attn_kwargs and "cu_seqlens_kv" in te_attn_kwargs + assert "max_seqlen_q" in te_attn_kwargs and "max_seqlen_kv" in te_attn_kwargs, ( + "TE DotProductAttention requires max_seqlen_q/kv for qkv_format='thd'; " + "missing keys cause silent degenerate-path fallback and NaN gradients." + ) + assert te_attn_kwargs["max_seqlen_q"] == llm_kwargs["max_seqlen"] + assert te_attn_kwargs["max_seqlen_kv"] == llm_kwargs["max_seqlen"] + + +def _bshd_supervised_pairs(input_ids, embeds, target_ids, replacements): + """Run the BSHD path (``replace_placeholders_and_build_targets`` + the + ``[:-1] / [1:]`` next-token shift used in + ``SALMAutomodel.prepare_inputs``) and return the ordered list of + supervised ``(input_embedding, target_token_id)`` pairs. + """ + from nemo.collections.speechlm2.models.salm import replace_placeholders_and_build_targets + + bshd_embs, bshd_targets, _ = replace_placeholders_and_build_targets( + input_ids=input_ids, + embeds=embeds, + padding_id=PAD, + placeholder_id=AUDIO, + replacements=[r.clone() for r in replacements], + target_ids=target_ids, + ) + bshd_embs = bshd_embs[:, :-1] + bshd_targets = bshd_targets[:, 1:] + + pairs = [] + B, T = bshd_targets.shape + for b in range(B): + for t in range(T): + tgt = bshd_targets[b, t].item() + if tgt != -100: + pairs.append((bshd_embs[b, t].clone(), tgt)) + return pairs + + +def _thd_supervised_pairs(input_ids, embeds, target_ids, replacements): + """Run the THD path (``pack_audio_into_text_embeds`` with the per-utt + next-token shift) and return the ordered list of supervised + ``(input_embedding, target_token_id)`` pairs. 
+ """ + out = pack_audio_into_text_embeds( + input_ids=input_ids, + embeds=embeds, + target_ids=target_ids, + replacements=[r.clone() for r in replacements], + padding_id=PAD, + placeholder_id=AUDIO, + ) + embs = out["inputs_embeds"] # [T_total, H] + labels = out["labels"] # [T_total] + pairs = [] + for t in range(labels.shape[0]): + tgt = labels[t].item() + if tgt != -100: + pairs.append((embs[t].clone(), tgt)) + return pairs + + +def _assert_pairs_equivalent(bshd_pairs, thd_pairs, *, atol=1e-6): + assert len(bshd_pairs) == len(thd_pairs), ( + f"BSHD has {len(bshd_pairs)} supervised pairs, THD has {len(thd_pairs)}. " + f"Both paths must yield the same ordered set of (input, target) pairs." + ) + for i, ((e_b, t_b), (e_t, t_t)) in enumerate(zip(bshd_pairs, thd_pairs)): + assert t_b == t_t, ( + f"Pair {i}: target_id mismatch — BSHD={t_b}, THD={t_t}. " + f"Per-utt next-token shift must align with global [:-1]/[1:] shift." + ) + assert torch.allclose(e_b, e_t, atol=atol), ( + f"Pair {i} (target={t_b}): input embedding mismatch between BSHD and THD. " f"BSHD={e_b}, THD={e_t}" + ) + + +def test_thd_and_bshd_supervised_pairs_match_basic(): + """First-principles invariant: BSHD and THD are different *layouts* of the + same data. The set of supervised ``(input_embedding, target_token_id)`` + pairs that contribute to the cross-entropy loss must be identical between + paths. Any divergence in this set means the THD path is feeding the model + something the BSHD path is not (or vice-versa). + """ + input_ids, embeds, target_ids, replacements = _basic_batch() + bshd_pairs = _bshd_supervised_pairs(input_ids, embeds, target_ids, replacements) + thd_pairs = _thd_supervised_pairs(input_ids, embeds, target_ids, replacements) + _assert_pairs_equivalent(bshd_pairs, thd_pairs) + + +def test_thd_and_bshd_supervised_pairs_match_no_audio_utt(): + """Pure-text utterance (no audio_locator).""" + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + loss_mask = torch.tensor([[False, False, True, True, True]]) + embeds = torch.randn(1, 5, 4) + target_ids = input_ids.where(loss_mask, -100) + bshd_pairs = _bshd_supervised_pairs(input_ids, embeds, target_ids, replacements=[]) + thd_pairs = _thd_supervised_pairs(input_ids, embeds, target_ids, replacements=[]) + _assert_pairs_equivalent(bshd_pairs, thd_pairs) + + +def test_thd_and_bshd_supervised_pairs_match_left_padded(): + """Left-padded utterances must yield the same supervised pairs.""" + input_ids = torch.tensor( + [ + [PAD, PAD, PAD, 1, 2, AUDIO, 3, 4], + [PAD, 5, 6, AUDIO, 7, AUDIO, 8, 9], + ] + ) + loss_mask = torch.tensor( + [ + [False, False, False, False, False, True, True, True], + [False, False, False, True, True, True, True, True], + ] + ) + embeds = torch.randn(2, 8, 4) + embeds[0, :3] = 0 # zero left-pad slots + embeds[1, :1] = 0 + target_ids = input_ids.where(loss_mask, -100) + replacements = [ + torch.randn(3, 4), # utt0 audio + torch.randn(2, 4), # utt1 first audio + torch.randn(4, 4), # utt1 second audio + ] + bshd_pairs = _bshd_supervised_pairs(input_ids, embeds, target_ids, replacements) + thd_pairs = _thd_supervised_pairs(input_ids, embeds, target_ids, replacements) + _assert_pairs_equivalent(bshd_pairs, thd_pairs) + + +def test_thd_and_bshd_supervised_pairs_match_b1(): + """Single-utterance batch.""" + input_ids = torch.tensor([[1, AUDIO, 2, 3, AUDIO, 4]]) + loss_mask = torch.tensor([[False, False, True, True, True, True]]) + embeds = torch.randn(1, 6, 4) + target_ids = input_ids.where(loss_mask, -100) + replacements = [torch.randn(2, 4), torch.randn(5, 
4)]
+    bshd_pairs = _bshd_supervised_pairs(input_ids, embeds, target_ids, replacements)
+    thd_pairs = _thd_supervised_pairs(input_ids, embeds, target_ids, replacements)
+    _assert_pairs_equivalent(bshd_pairs, thd_pairs)
+
+
+def test_padded_slots_have_zero_embed_and_ignored_label():
+    """Inter-utt padding (added for cp_size rounding) gets zero embedding,
+    -100 label, and contiguous position_ids."""
+    input_ids, embeds, target_ids, replacements = _basic_batch()
+    out = pack_audio_into_text_embeds(
+        input_ids=input_ids,
+        embeds=embeds,
+        target_ids=target_ids,
+        replacements=replacements,
+        padding_id=PAD,
+        placeholder_id=AUDIO,
+        cp_size=2,  # rounds the real lengths [11, 5] up to multiples of 4: 11→12 and 5→8
+    )
+    embs = out["inputs_embeds"]
+    labels = out["labels"]
+    pos = out["position_ids"]
+    # Utt 0 had real_len=11 and padded to 12 (next multiple of 4). Slot 11 is
+    # the pad slot.
+    assert torch.equal(embs[11], torch.zeros(2))
+    assert labels[11].item() == -100
+    assert pos[11].item() == 11  # contiguous with the utt's real positions
diff --git a/tests/collections/speechlm2/test_salm_parallelism_validation.py b/tests/collections/speechlm2/test_salm_parallelism_validation.py
new file mode 100644
index 000000000000..b2d3545eddfc
--- /dev/null
+++ b/tests/collections/speechlm2/test_salm_parallelism_validation.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for ``validate_parallelism_compatibility``.
+
+Pure-function tests — no Lightning, no model, no device mesh required.
+"""
+import warnings
+
+import pytest
+
+from nemo.collections.speechlm2.parts.parallel import validate_parallelism_compatibility
+
+
+# Combinations that must pass without raising or warning.
+
+
+def test_bshd_cp1_te_passes():
+    validate_parallelism_compatibility(
+        packed_sequences=False,
+        cp_size=1,
+        attn_backend="te",
+        nvte_fused_attn=None,
+        device_capability=(9, 0),  # H100
+    )
+
+
+def test_bshd_cp1_sdpa_passes():
+    validate_parallelism_compatibility(
+        packed_sequences=False,
+        cp_size=1,
+        attn_backend="sdpa",
+        nvte_fused_attn=None,
+        device_capability=(9, 0),
+    )
+
+
+def test_thd_cp1_te_with_fused_attn_off_passes():
+    validate_parallelism_compatibility(
+        packed_sequences=True,
+        cp_size=1,
+        attn_backend="te",
+        nvte_fused_attn="0",
+        device_capability=(12, 0),  # sm_120 — still fine because env is set
+    )
+
+
+def test_thd_cp2_te_with_fused_attn_off_passes():
+    validate_parallelism_compatibility(
+        packed_sequences=True,
+        cp_size=2,
+        attn_backend="te",
+        nvte_fused_attn="0",
+        device_capability=(12, 0),
+    )
+
+
+# BSHD + CP > 1 — hard error regardless of other knobs.
+
+
+@pytest.mark.parametrize("cp_size", [2, 4, 8])
+def test_bshd_with_cp_raises(cp_size):
+    with pytest.raises(ValueError, match="BSHD .* incompatible with cp_size > 1"):
+        validate_parallelism_compatibility(
+            packed_sequences=False,
+            cp_size=cp_size,
+            attn_backend="te",
+            nvte_fused_attn="0",
+            device_capability=(9, 0),
+        )
+
+
+# THD + non-TE attention — hard error. 
+ + +@pytest.mark.parametrize("attn_backend", ["sdpa", "flex"]) +def test_thd_with_non_te_attn_raises(attn_backend): + with pytest.raises(ValueError, match=r"THD.*requires.*attn=te"): + validate_parallelism_compatibility( + packed_sequences=True, + cp_size=1, + attn_backend=attn_backend, + nvte_fused_attn="0", + device_capability=(9, 0), + ) + + +# THD + TE + NVTE_FUSED_ATTN unset — warns on non-sm_120, raises on sm_120. + + +@pytest.mark.parametrize("nvte_fused_attn", [None, "", "1", "true"]) +def test_thd_te_without_fused_attn_off_warns_on_other_archs(nvte_fused_attn): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + validate_parallelism_compatibility( + packed_sequences=True, + cp_size=1, + attn_backend="te", + nvte_fused_attn=nvte_fused_attn, + device_capability=(9, 0), # H100 + ) + assert len(caught) == 1 + assert "NVTE_FUSED_ATTN" in str(caught[0].message) + + +@pytest.mark.parametrize("nvte_fused_attn", [None, "", "1", "true"]) +def test_thd_te_without_fused_attn_off_raises_on_sm120(nvte_fused_attn): + with pytest.raises(ValueError, match="NVTE_FUSED_ATTN"): + validate_parallelism_compatibility( + packed_sequences=True, + cp_size=1, + attn_backend="te", + nvte_fused_attn=nvte_fused_attn, + device_capability=(12, 0), # sm_120 + ) + + +def test_thd_te_with_fused_attn_off_does_not_warn_on_sm120(): + """The escape-hatch case: user has the env set, no warning fires.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + validate_parallelism_compatibility( + packed_sequences=True, + cp_size=1, + attn_backend="te", + nvte_fused_attn="0", + device_capability=(12, 0), + ) + assert len(caught) == 0 + + +def test_unknown_device_capability_warns_not_raises(): + """``device_capability=None`` (CPU-only env) treats the THD/TE check as a warning.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + validate_parallelism_compatibility( + packed_sequences=True, + cp_size=1, + attn_backend="te", + nvte_fused_attn=None, + device_capability=None, + ) + assert len(caught) == 1 + assert "NVTE_FUSED_ATTN" in str(caught[0].message)
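+
+
+# ---------------------------------------------------------------------------
+# Illustrative call-site sketch (comment only, never executed). It shows how
+# the five inputs parametrized above would typically be gathered before the
+# device mesh is built. ``cfg`` and its attribute paths are hypothetical
+# placeholders mirroring the documented YAML keys, not the real wiring inside
+# the SALM model code.
+#
+#   import os
+#   import torch
+#
+#   validate_parallelism_compatibility(
+#       packed_sequences=cfg.model.packed_sequences,
+#       cp_size=cp_size,  # however the parallelism strategy exposes it
+#       attn_backend=cfg.model.automodel_backend.attn,
+#       nvte_fused_attn=os.environ.get("NVTE_FUSED_ATTN"),
+#       device_capability=torch.cuda.get_device_capability() if torch.cuda.is_available() else None,
+#   )
+# ---------------------------------------------------------------------------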