Skip to content

Commit

Permalink
fix: address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Nov 14, 2024
1 parent bcc75db commit 0188cea
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 31 deletions.
6 changes: 2 additions & 4 deletions sbi/neural_nets/estimators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.neural_nets.estimators.categorical_net import (
CategoricalMADE,
CategoricalMassEstimator,
CategoricalNet,
CategoricalMADE,
)
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
MixedDensityEstimator,
)
from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator
from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
74 changes: 55 additions & 19 deletions sbi/neural_nets/estimators/categorical_net.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Optional
from typing import Callable, Optional

import torch
from nflows.nn.nde.made import MADE
Expand All @@ -15,29 +15,46 @@


class CategoricalMADE(MADE):
"""Conditional density (mass) estimation for a n-dim categorical random variable.
Takes as input parameters theta and learns the parameters p of a Categorical.
Defines log prob and sample functions.
"""

def __init__(
self,
categories, # Tensor[int]
hidden_features,
context_features=None,
num_blocks=2,
use_residual_blocks=True,
random_mask=False,
activation=F.relu,
dropout_probability=0.0,
use_batch_norm=False,
epsilon=1e-2,
custom_initialization=True,
num_categories: Tensor, # Tensor[int]
hidden_features: int,
context_features: Optional[int] = None,
num_blocks: int = 2,
use_residual_blocks: bool = True,
random_mask: bool = False,
activation: Callable = F.relu,
dropout_probability: float = 0.0,
use_batch_norm: bool = False,
epsilon: float = 1e-2,
custom_initialization: bool = True,
embedding_net: Optional[nn.Module] = nn.Identity(),
):
"""Initialize the neural net.
Args:
num_categories: number of categories for each variable. len(categories)
defines the number of input units, i.e., dimensionality of the features.
max(categories) defines the number of output units, i.e., the largest
number of categories.
num_hidden: number of hidden units per layer.
num_layers: number of hidden layers.
embedding_net: emebedding net for input.
"""
if use_residual_blocks and random_mask:
raise ValueError("Residual blocks can't be used with random masks.")

Check warning on line 52 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L51-L52

Added lines #L51 - L52 were not covered by tests

self.num_variables = len(categories)
self.num_categories = int(max(categories))
self.categories = categories
self.num_variables = len(num_categories)
self.num_categories = int(torch.max(num_categories))
self.mask = torch.zeros(self.num_variables, self.num_categories)
for i, c in enumerate(categories):
for i, c in enumerate(num_categories):
self.mask[i, :c] = 1

Check warning on line 58 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L54-L58

Added lines #L54 - L58 were not covered by tests

super().__init__(

Check warning on line 60 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L60

Added line #L60 was not covered by tests
Expand All @@ -60,7 +77,18 @@ def __init__(
if custom_initialization:
self._initialize()

Check warning on line 78 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L77-L78

Added lines #L77 - L78 were not covered by tests

def forward(self, inputs, context=None):
def forward(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor:
r"""Forward pass of the categorical density estimator network to compute the
conditional density at a given time.
Args:
input: Original data, x0. (batch_size, *input_shape)
condition: Conditioning variable. (batch_size, *condition_shape)
Returns:
Predicted categorical probabilities. (batch_size, *input_shape,
num_categories)
"""
embedded_context = self.embedding_net.forward(context)
return super().forward(inputs, context=embedded_context)

Check warning on line 93 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L92-L93

Added lines #L92 - L93 were not covered by tests

Expand All @@ -69,8 +97,16 @@ def compute_probs(self, outputs):
ps = ps / ps.sum(dim=-1, keepdim=True)
return ps

Check warning on line 98 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L96-L98

Added lines #L96 - L98 were not covered by tests

# outputs (batch_size, num_variables, num_categories)
def log_prob(self, inputs, context=None):
def log_prob(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor:
r"""Return log-probability of samples.
Args:
input: Input datapoints of shape `(batch_size, *input_shape)`.
context: Context of shape `(batch_size, *condition_shape)`.
Returns:
Log-probabilities of shape `(batch_size, num_variables, num_categories)`.
"""
outputs = self.forward(inputs, context=context)
outputs = outputs.reshape(*inputs.shape, self.num_categories)
ps = self.compute_probs(outputs)

Check warning on line 112 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L110-L112

Added lines #L110 - L112 were not covered by tests
Expand Down
4 changes: 2 additions & 2 deletions sbi/neural_nets/estimators/mixed_density_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
f"{input_batch_dim} do not match."
)

num_disc = self.discrete_net.net.num_variables
cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc)
num_discrete_variables = self.discrete_net.net.num_variables
cont_input, disc_input = _separate_input(input, num_discrete_variables)
# Embed continuous condition
embedded_condition = self.condition_embedding(condition)
# expand and repeat to match batch of inputs.
Expand Down
13 changes: 8 additions & 5 deletions sbi/neural_nets/net_builders/categorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def build_autoregressive_categoricalmassestimator(
z_score_y: Optional[str] = "independent",
num_hidden: int = 20,
num_layers: int = 2,
categories: Optional[Tensor] = None,
num_categories: Optional[Tensor] = None,
embedding_net: nn.Module = nn.Identity(),
):
"""Returns a density estimator for a categorical random variable.
Expand All @@ -86,13 +86,14 @@ def build_autoregressive_categoricalmassestimator(
num_hidden: Number of hidden units per layer.
num_layers: Number of hidden layers.
embedding_net: Embedding net for y.
num_categories: number of categories for each variable.
"""

if z_score_x != "none":
raise ValueError("Categorical input should not be z-scored.")
if categories is None:
if num_categories is None:
warnings.warn(

Check warning on line 95 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L92-L95

Added lines #L92 - L95 were not covered by tests
"Inferring categories from batch_x. Ensure all categories are present.",
"Inferring num_categories from batch_x. Ensure all categories are present.",
stacklevel=2,
)

Expand All @@ -108,10 +109,12 @@ def build_autoregressive_categoricalmassestimator(

batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
inferred_categories = tensor([unique(col).numel() for col in batch_x_discrete.T])
categories = categories if categories is not None else inferred_categories
num_categories = (

Check warning on line 112 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L110-L112

Added lines #L110 - L112 were not covered by tests
num_categories if num_categories is not None else inferred_categories
)

categorical_net = CategoricalMADE(

Check warning on line 116 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L116

Added line #L116 was not covered by tests
categories=categories,
num_categories=num_categories,
hidden_features=num_hidden,
context_features=y_numel,
num_blocks=num_layers,
Expand Down
9 changes: 8 additions & 1 deletion sbi/neural_nets/net_builders/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def build_mnle(
z_score_y: Optional[str] = "independent",
flow_model: str = "nsf",
categorical_model: str = "mlp",
num_categorical_columns: Optional[Tensor] = None,
embedding_net: nn.Module = nn.Identity(),
combined_embedding_net: Optional[nn.Module] = None,
num_transforms: int = 2,
Expand Down Expand Up @@ -108,6 +109,8 @@ def build_mnle(
data.
categorical_model: type of categorical net to use for the discrete part of
the data. Can be "made" or "mlp".
num_categorical_columns: Number of categorical columns of each variable in the
input data. If None, the function will infer this from the data.
embedding_net: Optional embedding network for y, required if y is > 1D.
combined_embedding_net: Optional embedding for combining the discrete
part of the input and the embedded condition into a joined
Expand Down Expand Up @@ -137,7 +140,10 @@ def build_mnle(
stacklevel=2,
)
# Separate continuous and discrete data.
num_disc = int(torch.sum(_is_discrete(batch_x)))
if num_categorical_columns is None:
num_disc = int(torch.sum(_is_discrete(batch_x)))
else:
num_disc = len(num_categorical_columns)

Check warning on line 146 in sbi/neural_nets/net_builders/mnle.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/mnle.py#L146

Added line #L146 was not covered by tests
cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)

# Set up y-embedding net with z-scoring.
Expand All @@ -160,6 +166,7 @@ def build_mnle(
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
num_categories=num_categorical_columns,
)
elif categorical_model == "mlp":
assert num_disc == 1, "MLP only supports 1D input."
Expand Down

0 comments on commit 0188cea

Please sign in to comment.