diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py
index af400d34c93c..9e1920b586ad 100644
--- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py
+++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py
@@ -66,7 +66,7 @@ def __getattribute__(self, name: str) -> None:
         logger.warning("WARNING: transformer_engine not installed. Using default recipe.")

 try:
-    from subquadratic_ops.rearrange import rearrange as subquadratic_ops_rearrange
+    from subquadratic_ops_torch.rearrange import rearrange as subquadratic_ops_rearrange
 except ImportError:

     def subquadratic_ops_rearrange(*args, **kwargs):
diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
index 9d4aea70f0e5..4b3990cd8f64 100644
--- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
+++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
@@ -54,11 +54,10 @@ def causal_conv1d_fn(*args, **kwargs):


 try:
-    from subquadratic_ops.b2b_causal_conv1d import b2b_causal_conv1d
-    from subquadratic_ops.causal_conv1d import causal_conv1d
-    from subquadratic_ops.fft_causal_conv1d import fft_causal_conv1d
-    from subquadratic_ops.fft_causal_conv1d import short_fft_is_available as is_fused_supported
-    from subquadratic_ops.implicit_filter import implicit_filter
+    from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d
+    from subquadratic_ops_torch.causal_conv1d import causal_conv1d
+    from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d
+    from subquadratic_ops_torch.implicit_filter import implicit_filter
 except ImportError:

     def causal_conv1d(*args, **kwargs):
@@ -73,10 +72,6 @@ def fft_causal_conv1d(*args, **kwargs):
         """Not imported: fft_causal_conv1d. An error will be raised if this is called."""
         raise ImportError("subquadratic_ops not installed. fft_causal_conv1d is not available.")

-    def is_fused_supported(*args, **kwargs):
-        """Not imported: is_fused_supported. An error will be raised if this is called."""
-        raise ImportError("subquadratic_ops not installed. is_fused_supported is not available.")
-
     def implicit_filter(*args, **kwargs):
         """Not imported: implicit_filter. An error will be raised if this is called."""
         raise ImportError("subquadratic_ops not installed. implicit_filter is not available.")
@@ -482,8 +477,6 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal
     # causal
     else:
         if use_subquadratic_ops:
-            if not is_fused_supported(k.shape[-1]):  # TODO: Remove this check after full subquadratic_ops support
-                raise ValueError("subquadratic_ops FFT causal convolution is not supported for this filter length.")
             y = fft_causal_conv1d(u, k.squeeze(0))
         else:
             k_f = torch.fft.rfft(k, n=fft_size) / fft_size
@@ -1351,6 +1344,10 @@ class B2BCausalConv1dModule(nn.Module):
     """Module that performs back-to-back causal convolution operations using optimized CUDA kernels.

     Combines the projection and mixer convolutions into a single optimized operation.
+
+    Note: This module stores references to other modules without registering them as child modules
+    to avoid duplicate parameters in the state dict. The actual parameters are owned by the
+    parent HyenaMixer module's hyena_proj_conv and mixer attributes.
     """

     def __init__(
@@ -1366,9 +1363,11 @@ def __init__(
         """
         super().__init__()
         self.b2b_causal_conv1d_fn = b2b_causal_conv1d
-        # Store references to the modules, not their weights
-        self._proj_conv_module = proj_conv_module
-        self._mixer_module = mixer_module
+        # Store references to the modules WITHOUT registering them as child modules
+        # Using object.__setattr__ bypasses PyTorch's module registration system
+        # This prevents their parameters from appearing in the state dict with the b2b_kernel prefix
+        object.__setattr__(self, '_proj_conv_module', proj_conv_module)
+        object.__setattr__(self, '_mixer_module', mixer_module)
         self._use_conv_bias = self._mixer_module.use_conv_bias
         self.operator_type = operator_type

diff --git a/tests/collections/llm/gpt/model/test_hyena_utils.py b/tests/collections/llm/gpt/model/test_hyena_utils.py
index 97668334f797..14fa84d1bf59 100644
--- a/tests/collections/llm/gpt/model/test_hyena_utils.py
+++ b/tests/collections/llm/gpt/model/test_hyena_utils.py
@@ -407,11 +407,9 @@ def test_fftconv_func_high_dimensional_input():
         assert "size" in str(e) or "dimension" in str(e), f"Unexpected error: {e}"


-@patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.is_fused_supported')
 @patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.fft_causal_conv1d')
-def test_fftconv_func_use_subquadratic_ops_success(mock_fft_causal_conv1d, mock_is_fused_supported):
+def test_fftconv_func_use_subquadratic_ops_success(mock_fft_causal_conv1d):
     """Test fftconv_func with use_subquadratic_ops=True when supported."""
-    mock_is_fused_supported.return_value = True
     mock_fft_causal_conv1d.return_value = torch.randn(2, 4, 8)

     batch_size = 2
@@ -429,26 +427,6 @@ def test_fftconv_func_use_subquadratic_ops_success(mock_fft_causal_conv1d, mock_
     mock_fft_causal_conv1d.assert_called_once()


-@patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.is_fused_supported')
-def test_fftconv_func_use_subquadratic_ops_not_supported(mock_is_fused_supported):
-    """Test fftconv_func with use_subquadratic_ops=True when not supported."""
-    mock_is_fused_supported.return_value = False
-
-    batch_size = 2
-    seq_len = 8
-    hidden_size = 4
-
-    u = torch.randn(batch_size, hidden_size, seq_len)
-    k = torch.randn(hidden_size, seq_len)
-    D = torch.randn(hidden_size)
-    dropout_mask = torch.ones(batch_size, hidden_size)
-
-    with pytest.raises(
-        ValueError, match="subquadratic_ops FFT causal convolution is not supported for this filter length."
-    ):
-        fftconv_func(u, k, D, dropout_mask, gelu=True, bidirectional=False, use_subquadratic_ops=True)
-
-
 class TestFallbackFunctions:
     """Test the fallback functions that are defined when subquadratic_ops import fails."""

@@ -483,17 +461,6 @@ def test_fft_causal_conv1d_fallback(self, mock_fft_causal_conv1d):
         with pytest.raises(ImportError, match="subquadratic_ops not installed. fft_causal_conv1d is not available."):
             mock_fft_causal_conv1d(torch.randn(1, 1, 1), torch.randn(1, 1))

-    @patch('nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils.is_fused_supported')
-    def test_is_fused_supported_fallback(self, mock_is_fused_supported):
-        """Test that the fallback is_fused_supported function raises ImportError."""
-        # Mock the function to raise ImportError
-        mock_is_fused_supported.side_effect = ImportError(
-            "subquadratic_ops not installed. is_fused_supported is not available."
-        )
-
-        with pytest.raises(ImportError, match="subquadratic_ops not installed. is_fused_supported is not available."):
-            mock_is_fused_supported(128)
-
     def test_fallback_functions_import_error_messages(self):
         """Test that all fallback functions have consistent error messages."""
         # Import the module to get access to the fallback functions
@@ -503,13 +470,11 @@ def test_fallback_functions_import_error_messages(self):
         assert hasattr(hyena_utils, 'causal_conv1d')
         assert hasattr(hyena_utils, 'b2b_causal_conv1d')
         assert hasattr(hyena_utils, 'fft_causal_conv1d')
-        assert hasattr(hyena_utils, 'is_fused_supported')

         # Test that they are callable
         assert callable(hyena_utils.causal_conv1d)
         assert callable(hyena_utils.b2b_causal_conv1d)
         assert callable(hyena_utils.fft_causal_conv1d)
-        assert callable(hyena_utils.is_fused_supported)

     def test_einops_import_error(self):
         """Test that the einops import error is raised with the correct message."""
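A note on the `fftconv_func` hunk above: with the `is_fused_supported` guard gone, `use_subquadratic_ops=True` now always dispatches to the fused `fft_causal_conv1d` kernel, while the eager branch keeps computing the same depthwise causal convolution via `torch.fft`. The sketch below is not from the PR; it assumes the fused kernel computes a plain depthwise causal convolution and ignores the `D` skip term, `gelu`, and dtype handling of the real `fftconv_func`. It shows why the eager branch's `rfft`/`irfft` round trip with `fft_size = 2 * seqlen` equals a direct causal convolution:

```python
import torch
import torch.nn.functional as F

B, H, L = 2, 4, 8
u = torch.randn(B, H, L)   # (batch, hidden, seq), as in the tests above
k = torch.randn(H, L)      # per-channel filter

# FFT path, mirroring the eager fallback in fftconv_func: zero-pad to 2L so the
# FFT's circular convolution becomes a linear (hence causal) one.
fft_size = 2 * L
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
u_f = torch.fft.rfft(u, n=fft_size)
y_fft = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :L]

# Direct path: depthwise conv1d with left-only padding. conv1d is a
# cross-correlation, so the filter is flipped to get a true convolution.
y_direct = F.conv1d(F.pad(u, (L - 1, 0)), k.flip(-1).unsqueeze(1), groups=H)

torch.testing.assert_close(y_fft, y_direct, rtol=1e-4, atol=1e-4)
```

Zero-padding to twice the sequence length is what turns the circular convolution into the linear, causal one, which is why the two branches agree up to floating-point error.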
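The `object.__setattr__` change in `B2BCausalConv1dModule.__init__` is the core of the fix, and it is easy to verify in isolation. A minimal sketch (class and attribute names here are hypothetical, not from the PR): assigning a module through the normal attribute path goes via `nn.Module.__setattr__`, which registers it as a child, so its parameters reappear in the holder's `state_dict()`; going through `object.__setattr__` keeps it a plain Python attribute.

```python
import torch
import torch.nn as nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4))


class Registered(nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner  # nn.Module.__setattr__ registers a child module


class Unregistered(nn.Module):
    def __init__(self, inner):
        super().__init__()
        # Bypasses nn.Module.__setattr__: `inner` stays a plain attribute,
        # usable in forward() but invisible to state_dict()/parameters().
        object.__setattr__(self, "_inner", inner)


inner = Inner()
print(list(Registered(inner).state_dict()))    # ['inner.weight'] -- duplicated ownership
print(list(Unregistered(inner).state_dict()))  # [] -- weights stay with their real owner
```

One trade-off worth noting: because the referenced modules are not children, calls like `.to()`, `.cuda()`, or `.eval()` on the holder will not propagate to them. That is acceptable here since, per the new docstring, the parent `HyenaMixer` owns and manages `hyena_proj_conv` and `mixer` directly.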
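On the test side, the surviving tests keep patching `hyena_utils.fft_causal_conv1d` rather than anything inside `subquadratic_ops_torch`. That is the standard "patch where it is looked up" rule of `unittest.mock`: `from ... import name` binds `name` into the importing module's namespace, and that binding, whether the real kernel or the `ImportError`-raising stub, is what the code consults at call time. A tiny self-contained illustration (all names hypothetical):

```python
from unittest.mock import patch


def fused_op(x):
    """Stand-in for an optional kernel (or its ImportError-raising fallback stub)."""
    return x + 1


def run(x):
    """Stand-in for fftconv_func: resolves the module-global name at call time."""
    return fused_op(x)


# Patch the name in the namespace where it is *used*, not where it is defined.
with patch(f"{__name__}.fused_op", return_value=42):
    assert run(0) == 42  # run() now sees the mock
assert run(0) == 1       # original binding restored on context exit
```

This is also why the fallback stubs matter for CI: the tests import `hyena_utils` and patch its module-level names, which works whether or not `subquadratic_ops_torch` is installed.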