
Add Transformer Encoder for ASR #15661

Open
nithinraok wants to merge 8 commits into main from transformer_asr_pr

Conversation

nithinraok (Member) commented May 4, 2026

Important

The Update branch button must only be pressed on very rare occasions.
An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do?

Adds a new ASR transformer encoder with frame stacking, configurable subsampling, FlexAttention-based full attention, and optional QK normalization.

Collection: ASR

Changelog

  • Add TransformerEncoder in nemo/collections/asr/modules/transformer_encoder.py as a lightweight pre-norm transformer encoder for ASR.
  • Add TransformerEncoderConfig dataclass to capture encoder hyperparameters such as d_model, n_heads, n_layers, qkv_bias, qk_norm, and
    subsampling_factor.
  • Add FeatureStacking pre-encoder module to stack consecutive frames, reduce sequence length, and project stacked features into the
    model dimension (see the sketch after this list).
  • Add transformer building blocks: FeedForward, MultiHeadAttention, and TransformerBlock.
  • Implement full-attention masking with PyTorch FlexAttention and padding-aware block masks.
  • Add optional per-head QK normalization before attention score computation.
  • Restrict the initial PR scope to attn_mode="full" and raise an error for unsupported future modes.
  • Export TransformerEncoder from nemo/collections/asr/modules/__init__.py for config-based instantiation.
  • Add example config in examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml showing how to train an RNNT/TDT model with the new
    encoder.
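
A minimal sketch of the frame-stacking idea behind FeatureStacking (illustrative only; the actual module in this PR may handle padding, naming, and the length update differently):

  import torch
  import torch.nn as nn

  class FeatureStackingSketch(nn.Module):
      """Stack `factor` consecutive frames and project them to d_model (sketch)."""

      def __init__(self, feat_in: int, d_model: int, factor: int):
          super().__init__()
          self.factor = factor
          self.proj = nn.Linear(feat_in * factor, d_model)

      def forward(self, x: torch.Tensor, lengths: torch.Tensor):
          # x: (B, T, feat_in); pad T up to a multiple of the stacking factor
          B, T, F = x.shape
          pad = (self.factor - T % self.factor) % self.factor
          if pad:
              x = torch.nn.functional.pad(x, (0, 0, 0, pad))
          # each output frame now covers `factor` consecutive input frames
          x = x.reshape(B, (T + pad) // self.factor, self.factor * F)
          # per-sample length update via ceil division
          lengths = (lengths + self.factor - 1) // self.factor
          return self.proj(x), lengths

  stack = FeatureStackingSketch(feat_in=128, d_model=512, factor=4)
  y, ylens = stack(torch.randn(2, 400, 128), torch.tensor([400, 360]))
  print(y.shape, ylens)   # torch.Size([2, 100, 512]) tensor([100, 90])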

Usage

  import torch

  from nemo.collections.asr.modules import TransformerEncoder

  encoder = TransformerEncoder(
      feat_in=128,
      d_model=512,
      n_heads=8,
      n_layers=12,
      drop_rate=0.1,
      qkv_bias=False,
      qk_norm=True,
      subsampling_factor=4,
      attn_mode="full",
  )

  audio_signal = torch.randn(2, 128, 400)   # (B, C, T)
  lengths = torch.tensor([400, 360])        # valid frame lengths

  encoded, encoded_lengths = encoder(audio_signal, lengths)

  print(encoded.shape)         # (B, D, T')
  print(encoded_lengths)       # ceil(length / subsampling_factor)

You can also instantiate it from a Hydra config:

  encoder:
    _target_: nemo.collections.asr.modules.transformer_encoder.TransformerEncoder
    feat_in: ${model.preprocessor.features}
    d_model: 1280
    n_heads: 16
    n_layers: 32
    drop_rate: 0.1
    qkv_bias: false
    qk_norm: true
    subsampling_factor: 8
    attn_mode: full

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove the label and add it again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you have read and followed the Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex, etc.)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

nithinraok added 2 commits May 4, 2026 08:37
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
copy-pr-bot Bot commented May 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added the ASR label May 4, 2026
@nithinraok nithinraok requested a review from pzelasko May 4, 2026 15:51
pzelasko (Collaborator) left a comment


nice work!

def __init__(self, cfg: TransformerEncoderConfig):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(cfg.d_model, 4 * cfg.d_model),
tango4j (Collaborator) commented May 5, 2026

This should never be hardcoded, since it is a frequently customized variable.
A cfg.ff_expansion field should be added, replacing the hardcoded 4.

Also, I think we should add this function to TransformerEncoderConfig

@property
def ff_hidden_size(self) -> int:
    return int(self.ff_expansion * self.d_model)

and change this line to:
nn.Linear(cfg.d_model, cfg.ff_hidden_size)

Member Author

I don't think adding a property to the dataclass is a good idea; however, I added the ff_expansion parameter.

Collaborator

No need to add a property function, but wrapping the variable with int() is necessary, and ff_expansion should be a float.
A one-line sanity check on ff_hidden_size would also make errors easier to track.
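
For concreteness, a minimal sketch of that suggestion (the config stand-in below is hypothetical, not the PR's actual TransformerEncoderConfig):

  from dataclasses import dataclass

  @dataclass
  class _CfgSketch:              # hypothetical stand-in for TransformerEncoderConfig
      d_model: int = 512
      ff_expansion: float = 4.0  # a float, per the comment above

  cfg = _CfgSketch()
  ff_hidden_size = int(cfg.ff_expansion * cfg.d_model)  # wrap with int()
  # one-line sanity check so bad ff_expansion values fail loudly
  assert ff_hidden_size > 0, f"Invalid ff_hidden_size={ff_hidden_size}"
  # then: nn.Linear(cfg.d_model, ff_hidden_size)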

Member Author

Yes, added now, please check.

Collaborator

Checked the change and the ff_hidden variable. Looks good.

Member Author

Please approve if everything looks good from your end.

nithinraok added 2 commits May 5, 2026 00:49
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
nithinraok (Member Author)

/ok to test ef9547a

github-actions Bot (Contributor) commented May 5, 2026

[🤖]: Hi @nithinraok 👋,

We wanted to let you know that a CICD pipeline for this PR just finished successfully.

So it might be time to merge this PR or get some approvals.

nithinraok added 2 commits May 5, 2026 12:14
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
pzelasko (Collaborator) left a comment


Thanks @nithinraok ! It looks good from my side, let's wait for @tango4j to approve and LGTM

tango4j (Collaborator) commented May 6, 2026

This doesn't need to happen in this PR, but to ensure standardization, we should later verify that this implementation is 100% compatible (transferable weights and identical results) with standard baselines like the Hugging Face Transformer.

If there are any custom scaling or structural tweaks, it would be great to make them optional flags. This keeps our baseline universally compatible while still allowing for optimizations when needed.

I think we should merge this after getting approval from @KunalDhawan and @stevehuang52 too.

KunalDhawan (Collaborator) left a comment

Great work, Nithin! See minor comments below; other than that, LGTM.


import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
Collaborator

We should add an import guard for flex_attention so that people using older PyTorch (<2.5) don't run into unexpected import errors.
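
For reference, a guarded import in the usual try/except style could look like this (a sketch; HAVE_FLEX_ATTENTION is an illustrative name, and as the reply below notes, the PR did not end up adding a guard):

  try:
      from torch.nn.attention.flex_attention import create_block_mask, flex_attention
      HAVE_FLEX_ATTENTION = True
  except ImportError:
      HAVE_FLEX_ATTENTION = False

  # later, e.g. in TransformerEncoder.__init__:
  # if not HAVE_FLEX_ATTENTION:
  #     raise ImportError("TransformerEncoder requires PyTorch >= 2.5 for flex_attention")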

Member Author

The minimum PyTorch version for NeMo is 2.6 and the current stable PyTorch version is 2.11, so I don't think we should add an import guard.

return pad_mask


class FeatureStacking(nn.Module):
Collaborator

The FeatureStacking class has a lot of overlap with StackingSubsampling in nemo/collections/asr/parts/submodules/subsampling.py (I see a difference in the length update, where you use individual sample lengths rather than the batch-padded length). Maybe it would be better to extend StackingSubsampling with the required changes and re-use it?

nithinraok (Member Author) replied May 7, 2026

Good point.

I think it's better to keep this as the base and make updates in that one instead. However, I feel the key difference in length computation makes it harder to merge them (see the sketch below).

Also, StackingSubsampling has get_sampling_frames() and get_streaming_cache_size() for the Conformer streaming pipeline, which doesn't apply here yet.
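
For concreteness, the per-sample length update mentioned above amounts to a ceil division (a sketch, not the exact PR code):

  import torch

  lengths = torch.tensor([400, 360])
  factor = 4
  # per-sample subsampled lengths, independent of the batch-padded T
  new_lengths = (lengths + factor - 1) // factor   # tensor([100, 90])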

return x


class TransformerEncoder(nn.Module):
Collaborator

Shouldn't TransformerEncoder inherit from NeuralModule and Exportable here, like other ASR encoders? This would enable @typecheck and avoid integration issues with save_to/restore_from and from_pretrained, for which NeuralModule provides serialization hooks.

Member Author

Good question, I wanted someone to ask this :D

save_to and restore_from work based on the model class (e.g., EncDecCTCBPE). From the next release, we're hoping to move away from requiring NeuralModule for new modules, as we see little benefit relative to the amount of code/work needed to support it.

I agree with your point on Exportable, but it's tightly coupled with NeuralModule for input and output types, and we need to come up with a better way for modules like this. CC: @pzelasko

Collaborator

+1, this model definition is a testing ground for a modern re-write / refactoring of NeMo Speech; it is OK not to re-use similar components in order to have an easier time developing.

I'm thinking we may need to sprinkle some @experimental tags across the codebase to indicate new modules whose APIs are likely to change until they stabilize. This is a good candidate (we can follow up in a separate PR before the release code cutoff).

class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder for ASR.

    Architecture: FeatureStacking -> EmbedScale -> LayerNorm -> N x TransformerBlock -> FinalNorm
Collaborator

Need to update the docstring; EmbedScale was removed in commit d277c3f.

Member Author

Updated.

x = self.embed_norm(x)

B, T, _ = x.shape
block_mask = create_block_mask(_make_padding_mod(length), B=B, H=1, Q_LEN=T, KV_LEN=T, device=x.device)
Collaborator

Non-blocker, but we are recomputing block_mask in every forward pass here. This might be wasted compute when length and T are unchanged across micro-batches, and especially for inference. Maybe worth caching this.

Member Author

Good catch. In practice T and length change every batch during training so a cache wouldn't help. During inference the cost of create_block_mask is small relative to the attention layers themselves IMO.

Collaborator

I remember playing with flex_attention over a year ago, and back then it required _compile=True to be efficient. But since Nithin is reporting good throughput numbers with this code, the internals may have been developed and optimized further since then (especially since flex_attention is compiled here with dynamic=True without issues, which back then was also not the case).
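
A minimal sketch of padding-aware masking with FlexAttention along these lines (the mask_mod body is an assumption about what the PR's _make_padding_mod does; create_block_mask and flex_attention are the real PyTorch APIs, available since 2.5):

  import torch
  from torch.nn.attention.flex_attention import create_block_mask, flex_attention

  def make_padding_mod(lengths: torch.Tensor):
      # allow attention only between valid (non-padded) positions
      def mask_mod(b, h, q_idx, kv_idx):
          return (q_idx < lengths[b]) & (kv_idx < lengths[b])
      return mask_mod

  B, H, T, D = 2, 8, 400, 64
  q = k = v = torch.randn(B, H, T, D)
  lengths = torch.tensor([400, 360])

  block_mask = create_block_mask(make_padding_mod(lengths), B=B, H=1, Q_LEN=T, KV_LEN=T, device=q.device)
  out = flex_attention(q, k, v, block_mask=block_mask)   # (B, H, T, D)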

def __init__(self, cfg: TransformerEncoderConfig):
    super().__init__()
    self.n_heads = cfg.n_heads
    self.head_dim = cfg.d_model // cfg.n_heads
Collaborator

Maybe worth adding a warning here when d_model % n_heads != 0. This would lead to silent integer truncation if the user picks incompatible values (e.g., d_model=512, n_heads=6 → head_dim=85, and the qkv reshape will misalign).

Member Author

Good point. Rather than a warning, I think we should raise an error. Updated (see the sketch below).
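
A minimal sketch of that check (function name illustrative):

  def check_head_dim(d_model: int, n_heads: int) -> int:
      # raise instead of warn: silent truncation would misalign the QKV reshape
      if d_model % n_heads != 0:
          raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
      return d_model // n_heads

  head_dim = check_head_dim(512, 8)   # 64
  # check_head_dim(512, 6) -> ValueError (would otherwise truncate to head_dim=85)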

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
tango4j (Collaborator) left a comment

Request to add a
pre_block_norm: bool = True
option for the "embed_norm" line (see the sketch below).

This would let the transformer encoder achieve parity with older-style (e.g., Whisper) transformer encoder implementations, making weights compatible with those models.
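
A sketch of how that flag could be wired in (field and attribute names are illustrative):

  # in TransformerEncoderConfig:
  #   pre_block_norm: bool = True   # set False for Whisper-style parity

  # in TransformerEncoder.__init__:
  self.embed_norm = nn.LayerNorm(cfg.d_model) if cfg.pre_block_norm else nn.Identity()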

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
nithinraok (Member Author)

/ok to test c212c41

github-actions Bot (Contributor) commented May 7, 2026

[🤖]: Hi @nithinraok 👋,

We wanted to let you know that a CICD pipeline for this PR just finished successfully.

So it might be time to merge this PR or get some approvals.

tango4j (Collaborator) left a comment

All the issues that I had have been resolved.
@nithinraok I appreciate the contribution.
