Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
)
from nemo.utils import logging


__all__ = ['ConformerEncoder', 'ConformerMultiLayerFeatureExtractor']


Expand Down Expand Up @@ -1065,15 +1066,24 @@ def setup_streaming_params(
else:
streaming_cfg.pre_encode_cache_size = 0

# Number of subsampled output frames produced from the pre-encode left-context
# cache. This is what we drop after subsampling so chunked inference matches a
# full pass. For convolutional subsampling with stride > 1, this is NOT a simple
# floor division — see ``ConvSubsampling.get_streaming_drop_size``.
if isinstance(streaming_cfg.pre_encode_cache_size, list):
if streaming_cfg.pre_encode_cache_size[1] >= 1:
streaming_cfg.drop_extra_pre_encoded = (
1 + (streaming_cfg.pre_encode_cache_size[1] - 1) // self.subsampling_factor
)
else:
streaming_cfg.drop_extra_pre_encoded = 0
pre_encode_cache = streaming_cfg.pre_encode_cache_size[1]
else:
pre_encode_cache = streaming_cfg.pre_encode_cache_size

if pre_encode_cache <= 0:
streaming_cfg.drop_extra_pre_encoded = 0
elif hasattr(self.pre_encode, "get_streaming_drop_size"):
streaming_cfg.drop_extra_pre_encoded = self.pre_encode.get_streaming_drop_size(pre_encode_cache)
else:
streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor
# Legacy fallback for custom pre_encode modules that pre-date
# ``get_streaming_drop_size``. Coincides with the convolutional recurrence at
# the default ``cache_size = subsampling_factor + 1`` but diverges otherwise.
streaming_cfg.drop_extra_pre_encoded = 1 + (pre_encode_cache - 1) // self.subsampling_factor

for m in self.layers.modules():
if hasattr(m, "_max_cache_len"):
Expand Down
36 changes: 36 additions & 0 deletions nemo/collections/asr/parts/submodules/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def get_sampling_frames(self):
def get_streaming_cache_size(self):
return 0

def get_streaming_drop_size(self, cache_size: int) -> int:
    """Return how many subsampled output frames a ``cache_size``-frame input yields.

    Streaming encoders use this to know how many leading frames of the encoder
    output correspond to the pre-encode left-context cache, so those frames can
    be dropped after subsampling. For stacking subsampling the mapping is exact:
    every ``subsampling_factor`` consecutive input frames stack into a single
    output frame, so the count is a plain floor division. Non-positive cache
    sizes produce no output frames.
    """
    return cache_size // self.subsampling_factor if cache_size > 0 else 0

def forward(self, x, lengths):
b, t, h = x.size()
pad_size = (self.subsampling_factor - (t % self.subsampling_factor)) % self.subsampling_factor
Expand Down Expand Up @@ -382,6 +394,30 @@ def get_sampling_frames(self):
def get_streaming_cache_size(self):
return [0, self.subsampling_factor + 1]

def get_streaming_drop_size(self, cache_size: int) -> int:
    """Return how many subsampled output frames a ``cache_size``-frame input yields.

    With stride > 1, convolutional subsampling does not shrink lengths by a plain
    floor division: every layer follows the recurrence
    ``L_next = floor((L + all_paddings - kernel_size) / stride) + 1`` (``ceil``
    when ``_ceil_mode`` is set), composed over ``_sampling_num`` layers — exactly
    what ``calc_length`` computes for the real forward pass. Delegating to that
    same helper keeps the streaming-drop count consistent with the encoder's own
    length bookkeeping for arbitrary ``cache_size``, instead of a divisor
    approximation that only happens to match the default
    ``subsampling_factor + 1`` cache size.
    """
    if cache_size <= 0:
        return 0
    # Reuse the forward pass's length arithmetic on a 0-dim float tensor.
    subsampled = calc_length(
        torch.tensor(cache_size, dtype=torch.float),
        all_paddings=self._left_padding + self._right_padding,
        kernel_size=self._kernel_size,
        stride=self._stride,
        ceil_mode=self._ceil_mode,
        repeat_num=self._sampling_num,
    )
    return int(subsampled.item())

def forward(self, x, lengths):
out_lengths = calc_length(
lengths,
Expand Down
98 changes: 97 additions & 1 deletion tests/collections/asr/test_asr_subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import torch

from nemo.collections.asr.models import ASRModel


Expand Down Expand Up @@ -59,3 +58,100 @@ def test_forward(self):
assert diff <= 0.2
diff = torch.mean(torch.abs(logprobs_batch4_split - logprobs_batch4_nosplit))
assert diff <= 0.2


class TestStreamingDropExtraPreEncoded:
    """Checks that ``ConvSubsampling.get_streaming_drop_size`` agrees with what the
    encoder actually emits for a ``cache_size``-long input segment.

    Regression coverage for the streaming/full-pass mismatch reported in
    https://github.com/NVIDIA-NeMo/NeMo/issues/15482 — the old formula
    ``1 + (cache_size - 1) // subsampling_factor`` drifts from the true convolutional
    recurrence once ``pre_encode_cache_size`` deviates from the default.
    """

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "subsampling,subsampling_factor",
        [
            ("striding", 4),
            ("striding", 8),
            ("dw_striding", 4),
            ("dw_striding", 8),
        ],
    )
    @pytest.mark.parametrize("cache_size", [1, 4, 8, 9, 11, 16, 32])
    def test_drop_size_matches_forward(self, subsampling, subsampling_factor, cache_size):
        """The output-frame count returned by the real ``forward`` for a
        ``cache_size``-long input to a causal conv subsampling must equal
        ``get_streaming_drop_size(cache_size)``.
        """
        from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling

        num_features = 80
        module = ConvSubsampling(
            subsampling=subsampling,
            subsampling_factor=subsampling_factor,
            feat_in=num_features,
            feat_out=16,
            conv_channels=16,
            subsampling_conv_chunking_factor=1,
            is_causal=True,
        )
        module.eval()
        features = torch.zeros(1, cache_size, num_features)
        feature_lens = torch.tensor([cache_size], dtype=torch.int64)
        with torch.no_grad():
            _, subsampled_lens = module(features, feature_lens)
        ground_truth = int(subsampled_lens[0].item())
        assert module.get_streaming_drop_size(cache_size) == ground_truth

    @pytest.mark.unit
    def test_drop_size_zero_for_empty_cache(self):
        from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling, StackingSubsampling

        conv_module = ConvSubsampling(
            subsampling="striding",
            subsampling_factor=8,
            feat_in=80,
            feat_out=16,
            conv_channels=16,
            subsampling_conv_chunking_factor=1,
            is_causal=True,
        )
        assert conv_module.get_streaming_drop_size(0) == 0

        stacking_module = StackingSubsampling(subsampling_factor=4, feat_in=80, feat_out=16)
        assert stacking_module.get_streaming_drop_size(0) == 0

    @pytest.mark.unit
    def test_drop_size_legacy_formula_diverges_for_non_default_cache(self):
        """Pin the bug being fixed: at the issue-reported ``cache_size=11`` with
        ``subsampling_factor=8``, the old formula says 2 while the true value is 3.
        """
        from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling

        conv_module = ConvSubsampling(
            subsampling="striding",
            subsampling_factor=8,
            feat_in=80,
            feat_out=16,
            conv_channels=16,
            subsampling_conv_chunking_factor=1,
            is_causal=True,
        )
        reported_cache = 11
        legacy_drop = 1 + (reported_cache - 1) // 8
        assert legacy_drop == 2  # old, wrong
        assert conv_module.get_streaming_drop_size(reported_cache) == 3  # new, matches the forward pass

    @pytest.mark.unit
    def test_stacking_drop_size(self):
        from nemo.collections.asr.parts.submodules.subsampling import StackingSubsampling

        stacking_module = StackingSubsampling(subsampling_factor=4, feat_in=80, feat_out=16)
        # StackingSubsampling.get_streaming_cache_size() returns 0 by default, but the
        # helper should still answer sensibly for any positive cache_size.
        assert stacking_module.get_streaming_drop_size(4) == 1
        assert stacking_module.get_streaming_drop_size(7) == 1
        assert stacking_module.get_streaming_drop_size(8) == 2
Loading