diff --git a/pytorch_widedeep/models/tabular/resnet/tab_resnet.py b/pytorch_widedeep/models/tabular/resnet/tab_resnet.py
index 22830ba..251eb5e 100644
--- a/pytorch_widedeep/models/tabular/resnet/tab_resnet.py
+++ b/pytorch_widedeep/models/tabular/resnet/tab_resnet.py
@@ -400,7 +400,7 @@ def __init__(
 
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
diff --git a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
index bca5db3..acb7297 100644
--- a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
@@ -307,7 +307,7 @@ def __init__(
         # therefore all related params are optional
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
diff --git a/pytorch_widedeep/models/tabular/transformers/saint.py b/pytorch_widedeep/models/tabular/transformers/saint.py
index 0e43c73..120eaeb 100644
--- a/pytorch_widedeep/models/tabular/transformers/saint.py
+++ b/pytorch_widedeep/models/tabular/transformers/saint.py
@@ -288,7 +288,7 @@ def __init__(
         # therefore all related params are optional
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
index 2cfeb08..0e8235b 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
@@ -324,7 +324,7 @@ def __init__(
         # therefore all related params are optional
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py b/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py
index 78035a2..c09f224 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py
@@ -321,7 +321,7 @@ def __init__(
         # Mlp
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
index b6f0064..725c600 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
@@ -311,7 +311,7 @@ def __init__(
 
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
diff --git a/requirements.txt b/requirements.txt
index d952856..e4d2ada 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 pandas>=1.3.5
 numpy>=1.21.6
-scipy>=1.7.3
+scipy>=1.7.3,<=1.12.0
 scikit-learn>=1.0.2
 gensim
 spacy
diff --git a/tests/test_model_components/test_mc_transformers.py b/tests/test_model_components/test_mc_transformers.py
index db58be5..5f5d702 100644
--- a/tests/test_model_components/test_mc_transformers.py
+++ b/tests/test_model_components/test_mc_transformers.py
@@ -208,6 +208,23 @@ def _build_model(model_name, params):
         return TabPerceiver(n_perceiver_blocks=2, n_latents=2, latent_dim=16, **params)
 
 
+# test out of a bug related to the mlp_hidden_dims attribute (https://github.com/jrzaurin/pytorch-widedeep/issues/204)
+@pytest.mark.parametrize(
+    "model_name", ["tabtransformer", "saint", "fttransformer", "tabfastformer"]
+)
+def test_mlphidden_dims(model_name):
+    params = {
+        "column_idx": {k: v for v, k in enumerate(colnames)},
+        "cat_embed_input": embed_input,
+        "continuous_cols": colnames[n_cols:],
+        "mlp_hidden_dims": [32, 16],
+    }
+
+    model = _build_model(model_name, params)
+    out = model(X_tab)
+    assert out.size(0) == 10 and out.size(1) == model.output_dim == 16
+
+
 @pytest.mark.parametrize(
     "embed_continuous, with_cls_token, model_name",
     [
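A minimal sketch of the code path this patch fixes, not part of the patch itself: building one of the affected models with a custom prediction MLP via mlp_hidden_dims, which previously failed because __init__ referenced self.mlp_hidden_dim. The column names, cardinalities and input tensor below are illustrative only; the constructor arguments mirror those used in the added test.

    # assumed usage sketch; column names / cardinalities are made up for illustration
    import torch
    from pytorch_widedeep.models import TabTransformer

    colnames = ["a", "b", "c", "d", "e"]
    cat_embed_input = [("a", 5), ("b", 5), ("c", 5)]  # (column name, number of categories)

    model = TabTransformer(
        column_idx={k: i for i, k in enumerate(colnames)},
        cat_embed_input=cat_embed_input,
        continuous_cols=["d", "e"],
        mlp_hidden_dims=[32, 16],  # exercises the previously broken MLP construction
    )

    X_tab = torch.cat(
        [torch.randint(0, 5, (10, 3)).float(), torch.rand(10, 2)], dim=1
    )  # three categorical index columns followed by two continuous columns
    out = model(X_tab)
    assert out.shape == (10, 16)  # final dim equals the last element of mlp_hidden_dims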