diff --git a/sbi/neural_nets/estimators/categorical_net.py b/sbi/neural_nets/estimators/categorical_net.py
index 2e8ce49d8..77a690147 100644
--- a/sbi/neural_nets/estimators/categorical_net.py
+++ b/sbi/neural_nets/estimators/categorical_net.py
@@ -4,21 +4,20 @@
 from typing import Optional
 
 import torch
-from torch import Tensor, nn, distributions
+from nflows.nn.nde.made import MADE
+from nflows.utils import torchutils
+from torch import Tensor, nn
 from torch.distributions import Categorical
 from torch.nn import Sigmoid, Softmax
+from torch.nn import functional as F
 
 from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
-from nflows.nn.nde.made import MADE
-from torch.nn import functional as F
-from nflows.utils import torchutils
-import numpy as np
 
 
 class CategoricalMADE(MADE):
     def __init__(
         self,
-        categories,  # List[int] or Tensor[int]
+        categories,  # Tensor[int]
         hidden_features,
         context_features=None,
         num_blocks=2,
@@ -31,7 +30,6 @@ def __init__(
         custom_initialization=True,
         embedding_net: Optional[nn.Module] = nn.Identity(),
     ):
-
         if use_residual_blocks and random_mask:
             raise ValueError("Residual blocks can't be used with random masks.")
 
@@ -65,22 +63,22 @@ def __init__(
     def forward(self, inputs, context=None):
         embedded_inputs = self.embedding_net.forward(inputs)
         return super().forward(embedded_inputs, context=context)
-
+
     def compute_probs(self, outputs):
-        ps = F.softmax(outputs, dim=-1)*self.mask
+        ps = F.softmax(outputs, dim=-1) * self.mask
         ps = ps / ps.sum(dim=-1, keepdim=True)
-        return ps.squeeze(-2)
-
+        return ps
+
     # outputs (batch_size, num_variables, num_categories)
     def log_prob(self, inputs, context=None):
         outputs = self.forward(inputs, context=context)
         outputs = outputs.reshape(*inputs.shape, self.num_categories)
         ps = self.compute_probs(outputs)
-
+
         # categorical log prob
-        log_prob = torch.log(ps.gather(-1, inputs.long()))
-        log_prob = log_prob.sum(dim=-1)
-
+        log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long()))
+        log_prob = log_prob.squeeze(-1).sum(dim=-1)
+
         return log_prob
 
     def sample(self, sample_shape, context=None):
@@ -88,10 +86,10 @@ def sample(self, sample_shape, context=None):
         if isinstance(sample_shape, int):
             sample_shape = (sample_shape,)
         sample_shape = torch.Size(sample_shape)
-
+
         # Calculate total number of samples
         num_samples = torch.prod(torch.tensor(sample_shape)).item()
-
+
         # Prepare context
         if context is not None:
             if context.ndim == 1:
@@ -99,20 +97,23 @@ def sample(self, sample_shape, context=None):
             context = torchutils.repeat_rows(context, num_samples)
         else:
             context = torch.zeros(num_samples, self.context_dim)
-
+
         with torch.no_grad():
             samples = torch.zeros(num_samples, self.num_variables)
             for variable in range(self.num_variables):
                 outputs = self.forward(samples, context)
-                outputs = outputs.reshape(num_samples, self.num_variables, self.num_categories)
+                outputs = outputs.reshape(
+                    num_samples, self.num_variables, self.num_categories
+                )
                 ps = self.compute_probs(outputs)
-                samples[:, variable] = Categorical(probs=ps[:,variable]).sample()
-
+                samples[:, variable] = Categorical(probs=ps[:, variable]).sample()
+
         return samples.reshape(*sample_shape, self.num_variables)
 
     def _initialize(self):
         pass
 
+
 class CategoricalNet(nn.Module):
     """Conditional density (mass) estimation for a categorical random variable.
 
@@ -145,6 +146,7 @@ def __init__(
         self.activation = Sigmoid()
         self.softmax = Softmax(dim=1)
         self.num_categories = num_categories
+        self.num_variables = 1
 
         # Maybe add embedding net in front.
         if embedding_net is not None:
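Note on the `compute_probs` change above: with several categorical variables of unequal cardinality, the network emits one logit tensor padded to `max(categories)`, and `self.mask` zeroes the padded slots before renormalizing. A standalone sketch of that masking step, with toy `categories` and random logits standing in for the network output (the class is assumed to build `self.mask` equivalently in the `__init__` body not shown in this diff):

```python
import torch
import torch.nn.functional as F

categories = torch.tensor([2, 3])  # toy example: one binary, one ternary variable
num_variables, max_categories = len(categories), int(categories.max())

# Hypothetical raw network outputs: (batch, num_variables, max_categories).
logits = torch.randn(4, num_variables, max_categories)

# 1 for valid category slots, 0 for padding slots.
mask = (torch.arange(max_categories) < categories.unsqueeze(-1)).float()

ps = F.softmax(logits, dim=-1) * mask   # zero out padded categories
ps = ps / ps.sum(dim=-1, keepdim=True)  # renormalize per variable

assert torch.allclose(ps.sum(-1), torch.ones(4, num_variables))
```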
diff --git a/sbi/neural_nets/estimators/mixed_density_estimator.py b/sbi/neural_nets/estimators/mixed_density_estimator.py
index dedba1b52..27cc2d2b6 100644
--- a/sbi/neural_nets/estimators/mixed_density_estimator.py
+++ b/sbi/neural_nets/estimators/mixed_density_estimator.py
@@ -80,8 +80,10 @@ def sample(
             sample_shape=sample_shape,
             condition=condition,
         )
-        # Trailing `1` because `Categorical` has event_shape `()`.
-        discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
+        num_variables = self.discrete_net.net.num_variables
+        discrete_samples = discrete_samples.reshape(
+            num_samples * batch_dim, num_variables
+        )
 
         # repeat the batch of embedded condition to match number of choices.
         condition_event_dim = embedded_condition.dim() - 1
@@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
             f"{input_batch_dim} do not match."
         )
 
-        cont_input, disc_input = _separate_input(input)
+        num_disc = self.discrete_net.net.num_variables
+        cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc)
         # Embed continuous condition
         embedded_condition = self.condition_embedding(condition)
         # expand and repeat to match batch of inputs.
@@ -204,3 +207,8 @@ def _separate_input(
     Assumes the discrete data to live in the last columns of input.
     """
     return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]
+
+
+def _is_discrete(input: Tensor) -> Tensor:
+    """Infer discrete columns in input data."""
+    return torch.tensor([torch.allclose(col, col.round()) for col in input.T])
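The `_is_discrete` helper added above is what lets the builders locate discrete columns automatically: a column is flagged as discrete when every entry rounds to itself. Note the heuristic would also flag a continuous column that happens to contain only whole numbers, which is one reason the builder below warns when categories are inferred. A quick illustration with a made-up tensor:

```python
import torch

# First column: continuous values; last two: integer-coded choices.
x = torch.tensor([[0.73, 1.0, 2.0],
                  [1.42, 0.0, 1.0]])

is_disc = torch.tensor([torch.allclose(col, col.round()) for col in x.T])
print(is_disc)  # tensor([False, True, True])
```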
diff --git a/sbi/neural_nets/net_builders/categorial.py b/sbi/neural_nets/net_builders/categorial.py
index c8662dd64..adf59cecf 100644
--- a/sbi/neural_nets/net_builders/categorial.py
+++ b/sbi/neural_nets/net_builders/categorial.py
@@ -1,16 +1,19 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see
 
+import warnings
 from typing import Optional
 
-from torch import Tensor, nn, unique, tensor
+from torch import Tensor, nn, tensor, unique
 
-from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet, CategoricalMADE
-from sbi.utils.nn_utils import get_numel
-from sbi.utils.sbiutils import (
-    standardizing_net,
-    z_score_parser,
+from sbi.neural_nets.estimators import (
+    CategoricalMADE,
+    CategoricalMassEstimator,
+    CategoricalNet,
 )
+from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
+from sbi.utils.nn_utils import get_numel
+from sbi.utils.sbiutils import standardizing_net, z_score_parser
 from sbi.utils.user_input_checks import check_data_device
@@ -70,7 +73,7 @@ def build_autoregressive_categoricalmassestimator(
     z_score_y: Optional[str] = "independent",
     num_hidden: int = 20,
     num_layers: int = 2,
-    num_variables: int = 1,
+    categories: Optional[Tensor] = None,
     embedding_net: nn.Module = nn.Identity(),
 ):
     """Returns a density estimator for a categorical random variable.
@@ -87,6 +90,11 @@
     if z_score_x != "none":
         raise ValueError("Categorical input should not be z-scored.")
+    if categories is None:
+        warnings.warn(
+            "Inferring categories from batch_x. Ensure all categories are present.",
+            stacklevel=2,
+        )
 
     check_data_device(batch_x, batch_y)
@@ -98,17 +106,18 @@
         standardizing_net(batch_y, structured_y), embedding_net
     )
 
-    categories = tensor([unique(variable).numel() for variable in batch_x.T])
-    categories = categories[-num_variables:]
+    batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
+    inferred_categories = tensor([unique(col).numel() for col in batch_x_discrete.T])
+    categories = categories if categories is not None else inferred_categories
 
     categorical_net = CategoricalMADE(
         categories=categories,
         hidden_features=num_hidden,
         context_features=y_numel,
         num_blocks=num_layers,
-        # TODO: embedding_net=embedding_net,
+        embedding_net=embedding_net,
     )
 
     return CategoricalMassEstimator(
         categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
-    )
\ No newline at end of file
+    )
diff --git a/sbi/neural_nets/net_builders/mnle.py b/sbi/neural_nets/net_builders/mnle.py
index 03c151daf..00801e4c6 100644
--- a/sbi/neural_nets/net_builders/mnle.py
+++ b/sbi/neural_nets/net_builders/mnle.py
@@ -8,8 +8,14 @@
 from torch import Tensor, nn
 
 from sbi.neural_nets.estimators import MixedDensityEstimator
-from sbi.neural_nets.estimators.mixed_density_estimator import _separate_input
-from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator, build_autoregressive_categoricalmassestimator
+from sbi.neural_nets.estimators.mixed_density_estimator import (
+    _is_discrete,
+    _separate_input,
+)
+from sbi.neural_nets.net_builders.categorial import (
+    build_autoregressive_categoricalmassestimator,
+    build_categoricalmassestimator,
+)
 from sbi.neural_nets.net_builders.flow import (
     build_made,
     build_maf,
@@ -26,10 +32,7 @@
     build_zuko_unaf,
 )
 from sbi.neural_nets.net_builders.mdn import build_mdn
-from sbi.utils.sbiutils import (
-    standardizing_net,
-    z_score_parser,
-)
+from sbi.utils.sbiutils import standardizing_net, z_score_parser
 from sbi.utils.user_input_checks import check_data_device
 
 model_builders = {
@@ -128,13 +131,14 @@
     warnings.warn(
         "The mixed neural likelihood estimator assumes that x contains "
-        "continuous data in the first n-1 columns (e.g., reaction times) and "
-        "categorical data in the last column (e.g., corresponding choices). If "
+        "continuous data in the first n-k columns (e.g., reaction times) and "
+        "categorical data in the last k columns (e.g., corresponding choices). If "
         "this is not the case for the passed `x` do not use this function.",
         stacklevel=2,
     )
 
     # Separate continuous and discrete data.
-    cont_x, disc_x = _separate_input(batch_x)
+    num_disc = int(torch.sum(_is_discrete(batch_x)))
+    cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)
 
     # Set up y-embedding net with z-scoring.
     z_score_y_bool, structured_y = z_score_parser(z_score_y)
@@ -158,6 +162,7 @@
             embedding_net=embedding_net,
         )
     elif categorical_model == "mlp":
+        assert num_disc == 1, "MLP only supports 1D input."
         discrete_net = build_categoricalmassestimator(
             disc_x,
             batch_y,
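For orientation, a minimal sketch of how the reworked builder might be called on this branch (toy data; passing `categories` explicitly avoids the inference warning and guards against categories that happen to be missing from the batch; the estimator's downstream API is not shown in this diff):

```python
import torch

from sbi.neural_nets.net_builders.categorial import (
    build_autoregressive_categoricalmassestimator,
)

# Toy batch: two discrete variables with 2 and 3 categories, 4D condition.
batch_x = torch.stack(
    [torch.randint(2, (100,)), torch.randint(3, (100,))], dim=1
).float()
batch_y = torch.randn(100, 4)

# Returns a CategoricalMassEstimator wrapping a CategoricalMADE.
estimator = build_autoregressive_categoricalmassestimator(
    batch_x, batch_y, categories=torch.tensor([2, 3])
)
```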