diff --git a/sbi/neural_nets/estimators/__init__.py b/sbi/neural_nets/estimators/__init__.py index a8792bd93..a885655ba 100644 --- a/sbi/neural_nets/estimators/__init__.py +++ b/sbi/neural_nets/estimators/__init__.py @@ -1,13 +1,11 @@ from sbi.neural_nets.estimators.base import ConditionalDensityEstimator from sbi.neural_nets.estimators.categorical_net import ( + CategoricalMADE, CategoricalMassEstimator, CategoricalNet, - CategoricalMADE, ) from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator -from sbi.neural_nets.estimators.mixed_density_estimator import ( - MixedDensityEstimator, -) +from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator from sbi.neural_nets.estimators.zuko_flow import ZukoFlow diff --git a/sbi/neural_nets/estimators/categorical_net.py b/sbi/neural_nets/estimators/categorical_net.py index 8e79e7072..b10cff021 100644 --- a/sbi/neural_nets/estimators/categorical_net.py +++ b/sbi/neural_nets/estimators/categorical_net.py @@ -1,7 +1,7 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see -from typing import Optional +from typing import Callable, Optional import torch from nflows.nn.nde.made import MADE @@ -15,29 +15,46 @@ class CategoricalMADE(MADE): + """Conditional density (mass) estimation for a n-dim categorical random variable. + + Takes as input parameters theta and learns the parameters p of a Categorical. + + Defines log prob and sample functions. + """ + def __init__( self, - categories, # Tensor[int] - hidden_features, - context_features=None, - num_blocks=2, - use_residual_blocks=True, - random_mask=False, - activation=F.relu, - dropout_probability=0.0, - use_batch_norm=False, - epsilon=1e-2, - custom_initialization=True, + num_categories: Tensor, # Tensor[int] + hidden_features: int, + context_features: Optional[int] = None, + num_blocks: int = 2, + use_residual_blocks: bool = True, + random_mask: bool = False, + activation: Callable = F.relu, + dropout_probability: float = 0.0, + use_batch_norm: bool = False, + epsilon: float = 1e-2, + custom_initialization: bool = True, embedding_net: Optional[nn.Module] = nn.Identity(), ): + """Initialize the neural net. + + Args: + num_categories: number of categories for each variable. len(categories) + defines the number of input units, i.e., dimensionality of the features. + max(categories) defines the number of output units, i.e., the largest + number of categories. + num_hidden: number of hidden units per layer. + num_layers: number of hidden layers. + embedding_net: emebedding net for input. + """ if use_residual_blocks and random_mask: raise ValueError("Residual blocks can't be used with random masks.") - self.num_variables = len(categories) - self.num_categories = int(max(categories)) - self.categories = categories + self.num_variables = len(num_categories) + self.num_categories = int(torch.max(num_categories)) self.mask = torch.zeros(self.num_variables, self.num_categories) - for i, c in enumerate(categories): + for i, c in enumerate(num_categories): self.mask[i, :c] = 1 super().__init__( @@ -60,7 +77,18 @@ def __init__( if custom_initialization: self._initialize() - def forward(self, inputs, context=None): + def forward(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor: + r"""Forward pass of the categorical density estimator network to compute the + conditional density at a given time. + + Args: + input: Original data, x0. (batch_size, *input_shape) + condition: Conditioning variable. (batch_size, *condition_shape) + + Returns: + Predicted categorical probabilities. (batch_size, *input_shape, + num_categories) + """ embedded_context = self.embedding_net.forward(context) return super().forward(inputs, context=embedded_context) @@ -69,8 +97,16 @@ def compute_probs(self, outputs): ps = ps / ps.sum(dim=-1, keepdim=True) return ps - # outputs (batch_size, num_variables, num_categories) - def log_prob(self, inputs, context=None): + def log_prob(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor: + r"""Return log-probability of samples. + + Args: + input: Input datapoints of shape `(batch_size, *input_shape)`. + context: Context of shape `(batch_size, *condition_shape)`. + + Returns: + Log-probabilities of shape `(batch_size, num_variables, num_categories)`. + """ outputs = self.forward(inputs, context=context) outputs = outputs.reshape(*inputs.shape, self.num_categories) ps = self.compute_probs(outputs) diff --git a/sbi/neural_nets/estimators/mixed_density_estimator.py b/sbi/neural_nets/estimators/mixed_density_estimator.py index 27cc2d2b6..e46d7baeb 100644 --- a/sbi/neural_nets/estimators/mixed_density_estimator.py +++ b/sbi/neural_nets/estimators/mixed_density_estimator.py @@ -147,8 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: f"{input_batch_dim} do not match." ) - num_disc = self.discrete_net.net.num_variables - cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc) + num_discrete_variables = self.discrete_net.net.num_variables + cont_input, disc_input = _separate_input(input, num_discrete_variables) # Embed continuous condition embedded_condition = self.condition_embedding(condition) # expand and repeat to match batch of inputs. diff --git a/sbi/neural_nets/net_builders/categorial.py b/sbi/neural_nets/net_builders/categorial.py index adf59cecf..0c98892b9 100644 --- a/sbi/neural_nets/net_builders/categorial.py +++ b/sbi/neural_nets/net_builders/categorial.py @@ -73,7 +73,7 @@ def build_autoregressive_categoricalmassestimator( z_score_y: Optional[str] = "independent", num_hidden: int = 20, num_layers: int = 2, - categories: Optional[Tensor] = None, + num_categories: Optional[Tensor] = None, embedding_net: nn.Module = nn.Identity(), ): """Returns a density estimator for a categorical random variable. @@ -86,13 +86,14 @@ def build_autoregressive_categoricalmassestimator( num_hidden: Number of hidden units per layer. num_layers: Number of hidden layers. embedding_net: Embedding net for y. + num_categories: number of categories for each variable. """ if z_score_x != "none": raise ValueError("Categorical input should not be z-scored.") - if categories is None: + if num_categories is None: warnings.warn( - "Inferring categories from batch_x. Ensure all categories are present.", + "Inferring num_categories from batch_x. Ensure all categories are present.", stacklevel=2, ) @@ -108,10 +109,12 @@ def build_autoregressive_categoricalmassestimator( batch_x_discrete = batch_x[:, _is_discrete(batch_x)] inferred_categories = tensor([unique(col).numel() for col in batch_x_discrete.T]) - categories = categories if categories is not None else inferred_categories + num_categories = ( + num_categories if num_categories is not None else inferred_categories + ) categorical_net = CategoricalMADE( - categories=categories, + num_categories=num_categories, hidden_features=num_hidden, context_features=y_numel, num_blocks=num_layers, diff --git a/sbi/neural_nets/net_builders/mnle.py b/sbi/neural_nets/net_builders/mnle.py index 00801e4c6..391abce8c 100644 --- a/sbi/neural_nets/net_builders/mnle.py +++ b/sbi/neural_nets/net_builders/mnle.py @@ -60,6 +60,7 @@ def build_mnle( z_score_y: Optional[str] = "independent", flow_model: str = "nsf", categorical_model: str = "mlp", + num_categorical_columns: Optional[Tensor] = None, embedding_net: nn.Module = nn.Identity(), combined_embedding_net: Optional[nn.Module] = None, num_transforms: int = 2, @@ -108,6 +109,8 @@ def build_mnle( data. categorical_model: type of categorical net to use for the discrete part of the data. Can be "made" or "mlp". + num_categorical_columns: Number of categorical columns of each variable in the + input data. If None, the function will infer this from the data. embedding_net: Optional embedding network for y, required if y is > 1D. combined_embedding_net: Optional embedding for combining the discrete part of the input and the embedded condition into a joined @@ -137,7 +140,10 @@ def build_mnle( stacklevel=2, ) # Separate continuous and discrete data. - num_disc = int(torch.sum(_is_discrete(batch_x))) + if num_categorical_columns is None: + num_disc = int(torch.sum(_is_discrete(batch_x))) + else: + num_disc = len(num_categorical_columns) cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc) # Set up y-embedding net with z-scoring. @@ -160,6 +166,7 @@ def build_mnle( num_hidden=hidden_features, num_layers=hidden_layers, embedding_net=embedding_net, + num_categories=num_categorical_columns, ) elif categorical_model == "mlp": assert num_disc == 1, "MLP only supports 1D input."