Merge pull request #205 from jrzaurin/fix_mlp_hidden_dim_attribute_bug

fixed an issue related to the mlp_hidden_dims param/attribute #204

jrzaurin committed Apr 4, 2024
2 parents b1bf2fa + 887265e commit 79a33f1
Showing 8 changed files with 24 additions and 7 deletions.
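For context, a minimal sketch of the bug this commit fixes (issue #204): constructing any of the affected models with mlp_hidden_dims set raised an AttributeError, because the constructors referenced a non-existent self.mlp_hidden_dim attribute when building the final MLP. The snippet below is illustrative only; the column names, category counts and tensor values are made up, and SAINT stands in for any of the affected models.

# Illustrative sketch of issue #204 (hypothetical data; SAINT is one of the affected models).
# Before this fix, passing mlp_hidden_dims raised AttributeError inside __init__ because the
# code referenced self.mlp_hidden_dim (singular); with the fix the model builds and runs.
import torch
from pytorch_widedeep.models import SAINT

colnames = ["a", "b", "c", "d", "e"]
cat_embed_input = [("a", 5), ("b", 5), ("c", 5)]  # (column name, number of categories)

# 10 rows: 3 integer-encoded categorical columns followed by 2 continuous columns
X_tab = torch.cat([torch.randint(0, 5, (10, 3)).float(), torch.rand(10, 2)], dim=1)

model = SAINT(
    column_idx={k: i for i, k in enumerate(colnames)},
    cat_embed_input=cat_embed_input,
    continuous_cols=colnames[3:],
    mlp_hidden_dims=[32, 16],  # the argument that triggered the bug
)

out = model(X_tab)
assert out.size(0) == 10 and out.size(1) == model.output_dim == 16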
2 changes: 1 addition & 1 deletion pytorch_widedeep/models/tabular/resnet/tab_resnet.py
@@ -400,7 +400,7 @@ def __init__(

         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
@@ -307,7 +307,7 @@ def __init__(
         # therefore all related params are optional
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
2 changes: 1 addition & 1 deletion pytorch_widedeep/models/tabular/transformers/saint.py
@@ -288,7 +288,7 @@ def __init__(
         # therefore all related params are optional
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
@@ -324,7 +324,7 @@ def __init__(
         # therefore all related params are optional
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
@@ -321,7 +321,7 @@ def __init__(
         # Mlp
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
@@ -311,7 +311,7 @@ def __init__(

         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
-                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
+                d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dims,
                 activation=(
                     "relu" if self.mlp_activation is None else self.mlp_activation
                 ),
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
 pandas>=1.3.5
 numpy>=1.21.6
-scipy>=1.7.3
+scipy>=1.7.3,<=1.12.0
 scikit-learn>=1.0.2
 gensim
 spacy
17 changes: 17 additions & 0 deletions tests/test_model_components/test_mc_transformers.py
@@ -208,6 +208,23 @@ def _build_model(model_name, params):
     return TabPerceiver(n_perceiver_blocks=2, n_latents=2, latent_dim=16, **params)
 
 
+# test out of a bug related to the mlp_hidden_dims attribute (https://github.com/jrzaurin/pytorch-widedeep/issues/204)
+@pytest.mark.parametrize(
+    "model_name", ["tabtransformer", "saint", "fttransformer", "tabfastformer"]
+)
+def test_mlphidden_dims(model_name):
+    params = {
+        "column_idx": {k: v for v, k in enumerate(colnames)},
+        "cat_embed_input": embed_input,
+        "continuous_cols": colnames[n_cols:],
+        "mlp_hidden_dims": [32, 16],
+    }
+
+    model = _build_model(model_name, params)
+    out = model(X_tab)
+    assert out.size(0) == 10 and out.size(1) == model.output_dim == 16
+
+
 @pytest.mark.parametrize(
     "embed_continuous, with_cls_token, model_name",
     [
