@@ -66,7 +66,7 @@ def __getattribute__(self, name: str) -> None:
logger.warning("WARNING: transformer_engine not installed. Using default recipe.")

try:
- from subquadratic_ops.rearrange import rearrange as subquadratic_ops_rearrange
+ from subquadratic_ops_torch.rearrange import rearrange as subquadratic_ops_rearrange
except ImportError:

def subquadratic_ops_rearrange(*args, **kwargs):
27 changes: 13 additions & 14 deletions nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
@@ -54,11 +54,10 @@ def causal_conv1d_fn(*args, **kwargs):


try:
- from subquadratic_ops.b2b_causal_conv1d import b2b_causal_conv1d
- from subquadratic_ops.causal_conv1d import causal_conv1d
- from subquadratic_ops.fft_causal_conv1d import fft_causal_conv1d
- from subquadratic_ops.fft_causal_conv1d import short_fft_is_available as is_fused_supported
- from subquadratic_ops.implicit_filter import implicit_filter
+ from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d
+ from subquadratic_ops_torch.causal_conv1d import causal_conv1d
+ from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d
+ from subquadratic_ops_torch.implicit_filter import implicit_filter
except ImportError:

def causal_conv1d(*args, **kwargs):
@@ -73,10 +72,6 @@ def fft_causal_conv1d(*args, **kwargs):
"""Not imported: fft_causal_conv1d. An error will be raised if this is called."""
raise ImportError("subquadratic_ops not installed. fft_causal_conv1d is not available.")

- def is_fused_supported(*args, **kwargs):
- """Not imported: is_fused_supported. An error will be raised if this is called."""
- raise ImportError("subquadratic_ops not installed. is_fused_supported is not available.")

def implicit_filter(*args, **kwargs):
"""Not imported: implicit_filter. An error will be raised if this is called."""
raise ImportError("subquadratic_ops not installed. implicit_filter is not available.")
@@ -482,8 +477,6 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal
# causal
else:
if use_subquadratic_ops:
- if not is_fused_supported(k.shape[-1]):  # TODO: Remove this check after full subquadratic_ops support
- raise ValueError("subquadratic_ops FFT causal convolution is not supported for this filter length.")
y = fft_causal_conv1d(u, k.squeeze(0))
else:
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
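For orientation, the non-fused branch above computes a causal convolution with rfft/irfft. A minimal pure-PyTorch sketch of that computation is shown below; this is a hypothetical reference helper, not the subquadratic_ops_torch API, and it ignores the bias term D, the dropout mask, and the gelu handling that fftconv_func applies around it.

```python
import torch

def fft_causal_conv_reference(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal convolution of u (B, H, L) with per-channel filter k (H, L) via FFT, O(L log L)."""
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen  # zero-pad so circular convolution becomes linear (causal) convolution
    k_f = torch.fft.rfft(k.float(), n=fft_size)          # (H, fft_size // 2 + 1)
    u_f = torch.fft.rfft(u.float(), n=fft_size)          # (B, H, fft_size // 2 + 1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]  # keep the first L (causal) samples
    return y.to(u.dtype)
```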
@@ -1351,6 +1344,10 @@ class B2BCausalConv1dModule(nn.Module):
"""Module that performs back-to-back causal convolution operations using optimized CUDA kernels.

Combines the projection and mixer convolutions into a single optimized operation.

+ Note: This module stores references to other modules without registering them as child modules
+ to avoid duplicate parameters in the state dict. The actual parameters are owned by the
+ parent HyenaMixer module's hyena_proj_conv and mixer attributes.
"""

def __init__(
@@ -1366,9 +1363,11 @@ def __init__(
"""
super().__init__()
self.b2b_causal_conv1d_fn = b2b_causal_conv1d
- # Store references to the modules, not their weights
- self._proj_conv_module = proj_conv_module
- self._mixer_module = mixer_module
+ # Store references to the modules WITHOUT registering them as child modules
+ # Using object.__setattr__ bypasses PyTorch's module registration system
+ # This prevents their parameters from appearing in the state dict with the b2b_kernel prefix
+ object.__setattr__(self, '_proj_conv_module', proj_conv_module)
+ object.__setattr__(self, '_mixer_module', mixer_module)
self._use_conv_bias = self._mixer_module.use_conv_bias
self.operator_type = operator_type
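For context on the object.__setattr__ trick in this hunk: nn.Module.__setattr__ normally intercepts attribute assignment and registers any nn.Module value as a child, which would duplicate the convolution parameters under this module's state-dict prefix. The standalone sketch below illustrates the effect under standard PyTorch behavior; Inner and Wrapper are hypothetical names used only for illustration.

```python
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class Wrapper(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        # A normal `self._inner = inner` would register `inner` as a submodule and
        # expose its parameters again under this wrapper's state-dict prefix.
        # object.__setattr__ keeps a plain Python reference instead.
        object.__setattr__(self, "_inner", inner)

    def forward(self, x):
        return self._inner(x)

inner = Inner()
wrapper = Wrapper(inner)
assert list(wrapper.state_dict().keys()) == []                          # nothing duplicated
assert list(inner.state_dict().keys()) == ["proj.weight", "proj.bias"]  # parameters stay with the owner
```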

37 changes: 1 addition & 36 deletions tests/collections/llm/gpt/model/test_hyena_utils.py
@@ -407,11 +407,9 @@ def test_fftconv_func_high_dimensional_input():
assert "size" in str(e) or "dimension" in str(e), f"Unexpected error: {e}"


- @patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.is_fused_supported')
@patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.fft_causal_conv1d')
- def test_fftconv_func_use_subquadratic_ops_success(mock_fft_causal_conv1d, mock_is_fused_supported):
+ def test_fftconv_func_use_subquadratic_ops_success(mock_fft_causal_conv1d):
"""Test fftconv_func with use_subquadratic_ops=True when supported."""
- mock_is_fused_supported.return_value = True
mock_fft_causal_conv1d.return_value = torch.randn(2, 4, 8)

batch_size = 2
@@ -429,26 +427,6 @@ def test_fftconv_func_use_subquadratic_ops_success(mock_fft_causal_conv1d, mock_
mock_fft_causal_conv1d.assert_called_once()


- @patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.is_fused_supported')
- def test_fftconv_func_use_subquadratic_ops_not_supported(mock_is_fused_supported):
- """Test fftconv_func with use_subquadratic_ops=True when not supported."""
- mock_is_fused_supported.return_value = False
-
- batch_size = 2
- seq_len = 8
- hidden_size = 4
-
- u = torch.randn(batch_size, hidden_size, seq_len)
- k = torch.randn(hidden_size, seq_len)
- D = torch.randn(hidden_size)
- dropout_mask = torch.ones(batch_size, hidden_size)
-
- with pytest.raises(
- ValueError, match="subquadratic_ops FFT causal convolution is not supported for this filter length."
- ):
- fftconv_func(u, k, D, dropout_mask, gelu=True, bidirectional=False, use_subquadratic_ops=True)
-
-
class TestFallbackFunctions:
"""Test the fallback functions that are defined when subquadratic_ops import fails."""

@@ -483,17 +461,6 @@ def test_fft_causal_conv1d_fallback(self, mock_fft_causal_conv1d):
with pytest.raises(ImportError, match="subquadratic_ops not installed. fft_causal_conv1d is not available."):
mock_fft_causal_conv1d(torch.randn(1, 1, 1), torch.randn(1, 1))

- @patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.is_fused_supported')
- def test_is_fused_supported_fallback(self, mock_is_fused_supported):
- """Test that the fallback is_fused_supported function raises ImportError."""
- # Mock the function to raise ImportError
- mock_is_fused_supported.side_effect = ImportError(
- "subquadratic_ops not installed. is_fused_supported is not available."
- )
-
- with pytest.raises(ImportError, match="subquadratic_ops not installed. is_fused_supported is not available."):
- mock_is_fused_supported(128)
-
def test_fallback_functions_import_error_messages(self):
"""Test that all fallback functions have consistent error messages."""
# Import the module to get access to the fallback functions
@@ -503,13 +470,11 @@ def test_fallback_functions_import_error_messages(self):
assert hasattr(hyena_utils, 'causal_conv1d')
assert hasattr(hyena_utils, 'b2b_causal_conv1d')
assert hasattr(hyena_utils, 'fft_causal_conv1d')
- assert hasattr(hyena_utils, 'is_fused_supported')

# Test that they are callable
assert callable(hyena_utils.causal_conv1d)
assert callable(hyena_utils.b2b_causal_conv1d)
assert callable(hyena_utils.fft_causal_conv1d)
- assert callable(hyena_utils.is_fused_supported)

def test_einops_import_error(self):
"""Test that the einops import error is raised with the correct message."""