Add Transformer Encoder for ASR #15661
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
def __init__(self, cfg: TransformerEncoderConfig):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(cfg.d_model, 4 * cfg.d_model),
This should never be hardcoded, since it is a very frequently customized variable. cfg.ff_expansion should be added, replacing the hardcoded 4.
Also, I think we should add this property to TransformerEncoderConfig:

    @property
    def ff_hidden_size(self) -> int:
        return int(self.ff_expansion * self.d_model)

and make this line:

    nn.Linear(cfg.d_model, cfg.ff_hidden_size)
I don't think adding a property to the dataclass is a good idea. However, I added the ff_expansion parameter.
No need to add the property function, but wrapping the variable with int() is necessary, and ff_expansion should be a float. One line of a sanity check on ff_hidden_size would make errors easier to track; a sketch of what this could look like follows.
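A minimal sketch of the suggestion, assuming illustrative defaults and a GELU activation (the PR's actual activation and default values are not shown here):

    from dataclasses import dataclass

    import torch.nn as nn


    @dataclass
    class TransformerEncoderConfig:
        d_model: int = 512
        ff_expansion: float = 4.0  # float so fractional expansions like 2.5 are valid


    class FeedForward(nn.Module):
        def __init__(self, cfg: TransformerEncoderConfig):
            super().__init__()
            # int() guards against float expansion factors; the check makes bad configs fail loudly
            ff_hidden_size = int(cfg.ff_expansion * cfg.d_model)
            assert ff_hidden_size > 0, f"ff_hidden_size must be positive, got {ff_hidden_size}"
            self.net = nn.Sequential(
                nn.Linear(cfg.d_model, ff_hidden_size),
                nn.GELU(),
                nn.Linear(ff_hidden_size, cfg.d_model),
            )

        def forward(self, x):
            return self.net(x)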
Yes, added now, please check.
Checked the change and ff_hidden variable. Looks good.
Please approve if all looks good from your end
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
/ok to test ef9547a
[🤖]: Hi @nithinraok 👋, We wanted to let you know that a CICD pipeline for this PR just finished successfully. So it might be time to merge this PR or get some approvals.
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
pzelasko left a comment
Thanks @nithinraok! It looks good from my side; let's wait for @tango4j to approve and LGTM.
This doesn't need to happen in this PR, but to ensure standardization, later on we probably need to verify that this implementation is 100% compatible (transferable weights and identical results) with standard baselines like the Hugging Face Transformer. If there are any custom scaling or structural tweaks, it would be great to make them optional flags. This keeps our baseline universally compatible while still allowing for optimizations when needed. I think we should merge this after getting approval from @KunalDhawan and @stevehuang52 too.
KunalDhawan left a comment
Great work, Nithin! See minor comments below, other than that LGTM
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
We should add an import guard for flex_attention so that people using older PyTorch (<2.5) don't run into unexpected import errors; a sketch of such a guard is below.
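For reference, a minimal sketch of the kind of guard being suggested (the flag name and error message are illustrative, not part of the PR):

    # Guarded import so environments with PyTorch < 2.5 fail at use time with a clear message
    try:
        from torch.nn.attention.flex_attention import create_block_mask, flex_attention

        HAVE_FLEX_ATTENTION = True
    except ImportError:
        create_block_mask, flex_attention = None, None
        HAVE_FLEX_ATTENTION = False


    def _require_flex_attention():
        if not HAVE_FLEX_ATTENTION:
            raise RuntimeError("flex_attention requires PyTorch >= 2.5; please upgrade torch.")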
The minimum PyTorch version for NeMo is 2.6 and the current stable PyTorch version is 2.11, so I don't think we should add an import guard.
    return pad_mask


class FeatureStacking(nn.Module):
The FeatureStacking class has a lot of overlap with StackingSubsampling in nemo/collections/asr/parts/submodules/subsampling.py (I see a difference in the length update, where you use the individual sample length, not the batch-padded length). Maybe it would be better to extend StackingSubsampling with the changes required and re-use it?
Good point. I think it's better to keep this as the base and update that one instead. However, I feel the key difference in length computation makes it harder to merge them. Also, StackingSubsampling has get_sampling_frames() and get_streaming_cache_size() for the Conformer streaming pipeline, which don't apply here yet.
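For context, a minimal sketch of the per-sample length handling being discussed; the module body is illustrative and not the PR's implementation:

    import torch
    import torch.nn as nn


    class FeatureStacking(nn.Module):
        """Stack `factor` consecutive frames and project them to d_model."""

        def __init__(self, feat_in: int, d_model: int, factor: int):
            super().__init__()
            self.factor = factor
            self.proj = nn.Linear(feat_in * factor, d_model)

        def forward(self, x: torch.Tensor, length: torch.Tensor):
            # x: (B, T, F); pad T so it divides evenly by the stacking factor
            B, T, F = x.shape
            pad = (self.factor - T % self.factor) % self.factor
            if pad:
                x = torch.nn.functional.pad(x, (0, 0, 0, pad))
            x = x.reshape(B, (T + pad) // self.factor, F * self.factor)
            # update each sample's valid length individually (ceil division),
            # rather than dividing the batch-padded length
            length = torch.div(length + self.factor - 1, self.factor, rounding_mode="floor")
            return self.proj(x), length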
    return x


class TransformerEncoder(nn.Module):
Shouldn't TransformerEncoder inherit from NeuralModule and Exportable here, like other ASR encoders? This would enable @typecheck and avoid integration issues with save_to/restore_from and from_pretrained, for which NeuralModule provides serialization hooks.
Good question, I was hoping someone would ask this :D
save_to and restore_from work based on the model class (e.g., EncDecCTCBPE). From the next release, we're hoping to move away from requiring NeuralModule for new modules, as we see little benefit for the amount of code/work it takes to support. I agree with your point on Exportable, but here it's tightly coupled with NeuralModule for input and output types; we need to come up with a better way for modules like this. CC: @pzelasko
+1, this model definition is a testing ground for a modern rewrite/refactoring of NeMo Speech; it is OK not to re-use similar components in order to have an easier time developing.
I am thinking maybe we need to sprinkle some @experimental tags across the codebase to indicate new modules whose APIs are likely to change until they are stabilized. This is a good candidate (we can follow up in a separate PR before the release code cut-off).
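For illustration, a hedged sketch of what such a tag could look like; the decorator below is a hypothetical helper written here for clarity, not an existing NeMo API:

    import warnings
    from functools import wraps

    import torch.nn as nn


    def experimental(cls):
        """Mark a class as experimental: warn when it is instantiated."""
        orig_init = cls.__init__

        @wraps(orig_init)
        def wrapped_init(self, *args, **kwargs):
            warnings.warn(
                f"{cls.__name__} is experimental; its API may change before it is stabilized.",
                category=UserWarning,
            )
            orig_init(self, *args, **kwargs)

        cls.__init__ = wrapped_init
        return cls


    @experimental
    class TransformerEncoder(nn.Module):
        pass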
class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder for ASR.

    Architecture: FeatureStacking -> EmbedScale -> LayerNorm -> N x TransformerBlock -> FinalNorm
The docstring needs updating: EmbedScale was removed in commit d277c3f.
    x = self.embed_norm(x)

    B, T, _ = x.shape
    block_mask = create_block_mask(_make_padding_mod(length), B=B, H=1, Q_LEN=T, KV_LEN=T, device=x.device)
Non-blocker, but here we are recomputing block_mask in every forward pass. This might be wasted compute when length and T are unchanged across micro-batches, and especially for inference. Maybe worth caching this.
Good catch. In practice, T and length change every batch during training, so a cache wouldn't help. During inference, the cost of create_block_mask is small relative to the attention layers themselves, IMO.
I remember playing with flex_attention over a year ago, and back then it required _compile=True to be efficient. But since Nithin is reporting good throughput numbers with this code, they may have developed and optimized the internals further (especially given that flex_attention here is compiled with dynamic=True and has no issues, which back then was also not true).
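For reference, a minimal sketch of the caching idea being discussed; the cache key, helper name, and the body of _make_padding_mod are illustrative guesses rather than the PR's code:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask


    def _make_padding_mod(length: torch.Tensor):
        # a position is valid only if both query and key indices fall within the sample's length
        def mask_mod(b, h, q_idx, kv_idx):
            return (q_idx < length[b]) & (kv_idx < length[b])

        return mask_mod


    _mask_cache: dict = {}


    def get_block_mask(length: torch.Tensor, T: int):
        # reuse the mask when the same (T, lengths) pair repeats, e.g. fixed-size inference batches
        key = (T, tuple(length.tolist()))
        if key not in _mask_cache:
            _mask_cache[key] = create_block_mask(
                _make_padding_mod(length), B=length.shape[0], H=1, Q_LEN=T, KV_LEN=T, device=length.device
            )
        return _mask_cache[key]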
def __init__(self, cfg: TransformerEncoderConfig):
    super().__init__()
    self.n_heads = cfg.n_heads
    self.head_dim = cfg.d_model // cfg.n_heads
Maybe worth adding a warning here when d_model % n_heads != 0. Otherwise this leads to silent integer truncation if the user picks incompatible values (e.g., d_model=512, n_heads=6 → head_dim=85, but the QKV reshape will misalign).
Good point. Rather than a warning, I think we should raise an error. Updated.
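A minimal sketch of such a check (the exact error message is an assumption, not the PR's wording):

    def __init__(self, cfg: TransformerEncoderConfig):
        super().__init__()
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError(
                f"d_model ({cfg.d_model}) must be divisible by n_heads ({cfg.n_heads})"
            )
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads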
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
tango4j left a comment
Request to add a

    pre_block_norm: bool = True

option for the "embed_norm" line, as sketched below. This can make the TF encoder implementation achieve parity with old-style (e.g., Whisper) TF encoder implementations, so that weights are compatible with those models.
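A minimal sketch of the flag being requested; the field placement and default are illustrative:

    from dataclasses import dataclass

    import torch.nn as nn


    @dataclass
    class TransformerEncoderConfig:
        d_model: int = 512
        pre_block_norm: bool = True  # set False for Whisper-style encoders without a pre-block norm


    class TransformerEncoder(nn.Module):
        def __init__(self, cfg: TransformerEncoderConfig):
            super().__init__()
            # apply LayerNorm before the transformer blocks only when requested
            self.embed_norm = nn.LayerNorm(cfg.d_model) if cfg.pre_block_norm else nn.Identity()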
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
/ok to test c212c41
[🤖]: Hi @nithinraok 👋, We wanted to let you know that a CICD pipeline for this PR just finished successfully. So it might be time to merge this PR or get some approvals.
tango4j left a comment
All the issues that I had have been resolved.
@nithinraok I appreciate the contribution.
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch is never blocking the merge of a PR. Please reach out to the automation team before pressing that button.
What does this PR do?
Adds a new ASR transformer encoder with frame stacking, configurable subsampling, FlexAttention-based full attention, and optional QK Norm.
Collection: ASR
Changelog
- Adds nemo/collections/asr/modules/transformer_encoder.py as a lightweight pre-norm transformer encoder for ASR.
- Adds a TransformerEncoderConfig dataclass to capture encoder hyperparameters such as d_model, n_heads, n_layers, qkv_bias, qk_norm, and subsampling_factor.
- Adds a FeatureStacking pre-encoder module to stack consecutive frames, reduce sequence length, and project stacked features into the model dimension.
- Adds FeedForward, MultiHeadAttention, and TransformerBlock submodules.
- Registers the encoder in nemo/collections/asr/modules/__init__.py for config-based instantiation.
- Adds examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml showing how to train an RNNT/TDT model with the new encoder.
Usage
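A hedged sketch of direct usage; the hyperparameter values, input shapes, and forward signature are assumptions, not taken from the PR:

    import torch

    from nemo.collections.asr.modules.transformer_encoder import TransformerEncoder, TransformerEncoderConfig

    # Illustrative hyperparameters; the PR's actual defaults may differ.
    cfg = TransformerEncoderConfig(d_model=512, n_heads=8, n_layers=12, subsampling_factor=4)
    encoder = TransformerEncoder(cfg)

    audio_features = torch.randn(2, 400, 80)  # (batch, time, features) -- shape convention assumed
    lengths = torch.tensor([400, 320])        # valid frames per sample
    encoded, encoded_len = encoder(audio_features, lengths)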
You can also instantiate it from Hydra config with:
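A hedged sketch of what config-driven instantiation could look like, assuming standard Hydra recursive instantiation; the _target_ paths and field values are illustrative:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "_target_": "nemo.collections.asr.modules.transformer_encoder.TransformerEncoder",
            "cfg": {
                "_target_": "nemo.collections.asr.modules.transformer_encoder.TransformerEncoderConfig",
                "d_model": 512,
                "n_heads": 8,
                "n_layers": 12,
                "subsampling_factor": 4,
            },
        }
    )
    encoder = instantiate(cfg)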
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.
Additional Information