Skip to content

Commit

Permalink
fix: change net kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Sep 13, 2024
1 parent d070d2a commit 089d1d3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sbi/neural_nets/net_builders/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def build_mnle(
z_score_x: Optional[str] = "independent",
z_score_y: Optional[str] = "independent",
flow_model: str = "nsf",
categorical_model: str = "made",
categorical_model: str = "mlp",
embedding_net: nn.Module = nn.Identity(),
combined_embedding_net: Optional[nn.Module] = None,
num_transforms: int = 2,
Expand Down Expand Up @@ -104,7 +104,7 @@ def build_mnle(
flow_model: type of flow model to use for the continuous part of the
data.
categorical_model: type of categorical net to use for the discrete part of
the data. Can be "made" or "categorical".
the data. Can be "made" or "mlp".
embedding_net: Optional embedding network for y, required if y is > 1D.
combined_embedding_net: Optional embedding for combining the discrete
part of the input and the embedded condition into a joined
Expand Down Expand Up @@ -157,7 +157,7 @@ def build_mnle(
num_layers=hidden_layers,
embedding_net=embedding_net,
)
elif categorical_model == "categorical":
elif categorical_model == "mlp":
discrete_net = build_categoricalmassestimator(
disc_x,
batch_y,
Expand All @@ -169,7 +169,7 @@ def build_mnle(
)
else:
raise ValueError(

Check warning on line 171 in sbi/neural_nets/net_builders/mnle.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/mnle.py#L171

Added line #L171 was not covered by tests
f"Unknown categorical net {categorical_model}. Must be 'made' or 'categorical'."
f"Unknown categorical net {categorical_model}. Must be 'made' or 'mlp'."
)

if combined_embedding_net is None:
Expand Down

0 comments on commit 089d1d3

Please sign in to comment.