Add CategoricalMADE #1269

Open · wants to merge 17 commits into main
6 changes: 2 additions & 4 deletions sbi/neural_nets/estimators/__init__.py
@@ -1,12 +1,10 @@
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.neural_nets.estimators.categorical_net import (
CategoricalMADE,
CategoricalMassEstimator,
CategoricalNet,
)
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
MixedDensityEstimator,
)
from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator
from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
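After this change, `CategoricalMADE` is exposed alongside the existing estimators. A minimal import sketch of the updated surface (assuming the PR branch is installed):

```python
# Hypothetical downstream import, based on the __init__.py diff above.
from sbi.neural_nets.estimators import (
    CategoricalMADE,
    CategoricalMassEstimator,
    MixedDensityEstimator,
)
```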
165 changes: 96 additions & 69 deletions sbi/neural_nets/estimators/categorical_net.py
@@ -1,18 +1,20 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Optional
from typing import Callable, Optional

import torch
from nflows.nn.nde.made import MADE
from nflows.utils import torchutils
from torch import Tensor, nn
from torch.distributions import Categorical
from torch.nn import Sigmoid, Softmax
from torch.nn import functional as F

from sbi.neural_nets.estimators.base import ConditionalDensityEstimator


class CategoricalNet(nn.Module):
"""Conditional density (mass) estimation for a categorical random variable.
class CategoricalMADE(MADE):
"""Conditional density (mass) estimation for a n-dim categorical random variable.

Takes as input parameters theta and learns the parameters p of a Categorical.

@@ -21,93 +23,118 @@

def __init__(
self,
num_input: int,
num_categories: int,
num_hidden: int = 20,
num_layers: int = 2,
embedding_net: Optional[nn.Module] = None,
num_categories: Tensor, # Tensor[int]
hidden_features: int,
context_features: Optional[int] = None,
num_blocks: int = 2,
use_residual_blocks: bool = True,
random_mask: bool = False,
activation: Callable = F.relu,
dropout_probability: float = 0.0,
use_batch_norm: bool = False,
epsilon: float = 1e-2,
embedding_net: nn.Module = nn.Identity(),
):
"""Initialize the neural net.

Args:
num_input: number of input units, i.e., dimensionality of the features.
num_categories: number of output units, i.e., number of categories.
num_categories: number of categories for each variable. len(num_categories)
defines the number of input units, i.e., dimensionality of the features.
max(num_categories) defines the number of output units, i.e., the largest
number of categories.
num_hidden: number of hidden units per layer.
num_layers: number of hidden layers.
embedding_net: embedding net for input.
"""
super().__init__()

self.num_hidden = num_hidden
self.num_input = num_input
self.activation = Sigmoid()
self.softmax = Softmax(dim=1)
self.num_categories = num_categories

# Maybe add embedding net in front.
if embedding_net is not None:
self.input_layer = nn.Sequential(
embedding_net, nn.Linear(num_input, num_hidden)
)
else:
self.input_layer = nn.Linear(num_input, num_hidden)

# Repeat hidden units hidden layers times.
self.hidden_layers = nn.ModuleList()
for _ in range(num_layers):
self.hidden_layers.append(nn.Linear(num_hidden, num_hidden))

self.output_layer = nn.Linear(num_hidden, num_categories)
if use_residual_blocks and random_mask:
raise ValueError("Residual blocks can't be used with random masks.")

def forward(self, condition: Tensor) -> Tensor:
"""Return categorical probability predicted from a batch of inputs.
self.num_variables = len(num_categories)
self.num_categories = int(torch.max(num_categories))
self.mask = torch.zeros(self.num_variables, self.num_categories)
for i, c in enumerate(num_categories):
self.mask[i, :c] = 1

Args:
condition: batch of context parameters for the net.

Returns:
Tensor: batch of predicted categorical probabilities.
"""
# forward path
condition = self.activation(self.input_layer(condition))

# iterate n hidden layers, input condition and calculate tanh activation
for layer in self.hidden_layers:
condition = self.activation(layer(condition))
super().__init__(
self.num_variables,
hidden_features,
context_features=context_features,
num_blocks=num_blocks,
output_multiplier=self.num_categories,
use_residual_blocks=use_residual_blocks,
random_mask=random_mask,
activation=activation,
dropout_probability=dropout_probability,
use_batch_norm=use_batch_norm,
)

return self.softmax(self.output_layer(condition))
self.embedding_net = embedding_net
self.hidden_features = hidden_features
self.epsilon = epsilon
self.context_features = context_features

def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
"""Return categorical log probability of categories input, given condition.
def forward(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
r"""Forward pass of the categorical density estimator network to compute the
conditional density at a given time.

Args:
input: categories to evaluate.
condition: parameters.
input: Input datapoints of shape `(batch_size, *input_shape)`.
condition: Conditioning variable. `(batch_size, *condition_shape)`

Returns:
Tensor: log probs with shape (input.shape[0],)
Predicted categorical logits, flattened to shape
`(batch_size, num_variables * num_categories)`.
"""
# Predict categorical ps and evaluate.
ps = self.forward(condition)
# Squeeze the last dimension (event dim) because `Categorical` has
# `event_shape=()` but our data usually has an event_shape of `(1,)`.
return Categorical(probs=ps).log_prob(input.squeeze(dim=-1))
embedded_condition = self.embedding_net.forward(condition)
out = super().forward(input, context=embedded_condition)
return out.masked_fill(~self.mask.bool().flatten(), float("-inf"))

def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
"""Returns samples from categorical random variable with probs predicted from
the neural net.
def log_prob(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
r"""Return log-probability of samples.

Args:
sample_shape: number of samples to obtain.
condition: batch of parameters for prediction.
input: Input datapoints of shape `(batch_size, *input_shape)`.
condition: Conditioning variable. `(batch_size, *condition_shape)`.

Returns:
Tensor: Samples with shape (num_samples, 1)
Log-probabilities of shape `(batch_size,)` (summed over variables).
"""
outputs = self.forward(input, condition=condition)
outputs = outputs.reshape(*input.shape, self.num_categories)
log_prob = Categorical(logits=outputs).log_prob(input).sum(dim=-1)
return log_prob

def sample(self, sample_shape, context=None):
# Ensure sample_shape is a tuple
if isinstance(sample_shape, int):
sample_shape = (sample_shape,)

sample_shape = torch.Size(sample_shape)

# Calculate total number of samples
num_samples = int(torch.prod(torch.tensor(sample_shape)))

# Prepare context
if context is not None:
batch_dim = context.shape[0]
if context.ndim == 2:
context = context.unsqueeze(0)
if batch_dim == 1:
context = torchutils.repeat_rows(context, num_samples)
else:
context_dim = 0 if self.context_features is None else self.context_features
context = torch.zeros(num_samples, context_dim)
batch_dim = 1

with torch.no_grad():
samples = torch.randn(num_samples, batch_dim, self.num_variables)
for i in range(self.num_variables):
outputs = self.forward(samples, context)
outputs = outputs.reshape(*samples.shape, self.num_categories)
# Sample only dimension i in this pass; earlier dimensions stay
# fixed so the autoregressive conditioning remains consistent.
samples[:, :, i] = Categorical(logits=outputs[:, :, i]).sample()

# Predict Categorical ps and sample.
ps = self.forward(condition)
return Categorical(probs=ps).sample(sample_shape=sample_shape)
return samples.reshape(*sample_shape, batch_dim, self.num_variables)
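Taken together, the new class is constructed from a per-variable category count and evaluated or sampled autoregressively. A minimal end-to-end sketch, assuming the signatures shown in this diff (not run against the PR branch):

```python
import torch
from sbi.neural_nets.estimators import CategoricalMADE

# Two categorical variables with 3 and 5 categories, conditioned on a
# 4-dimensional context vector.
num_categories = torch.tensor([3, 5])
net = CategoricalMADE(
    num_categories=num_categories,
    hidden_features=20,
    context_features=4,
    num_blocks=2,
)

inputs = torch.tensor([[0.0, 4.0], [2.0, 1.0]])  # (batch_size=2, num_variables=2)
condition = torch.randn(2, 4)                    # (batch_size=2, context_features=4)

log_probs = net.log_prob(inputs, condition=condition)   # shape: (2,)

# Sampling with a single conditioning vector (batch of one).
samples = net.sample((10,), context=torch.randn(1, 4))  # shape: (10, 1, 2)
```

Note the padding scheme: variables with fewer than `max(num_categories)` categories have their invalid logits masked to `-inf` in `forward`, so `Categorical` assigns them zero probability.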


class CategoricalMassEstimator(ConditionalDensityEstimator):
@@ -117,12 +144,12 @@
"""

def __init__(
self, net: CategoricalNet, input_shape: torch.Size, condition_shape: torch.Size
self, net: CategoricalMADE, input_shape: torch.Size, condition_shape: torch.Size
) -> None:
"""Initialize the mass estimator.

Args:
net: CategoricalNet.
net: CategoricalMADE.
input_shape: Shape of the input data.
condition_shape: Shape of the condition data
"""
14 changes: 11 additions & 3 deletions sbi/neural_nets/estimators/mixed_density_estimator.py
@@ -80,8 +80,10 @@ def sample(
sample_shape=sample_shape,
condition=condition,
)
# Trailing `1` because `Categorical` has event_shape `()`.
discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
num_variables = self.discrete_net.net.num_variables
discrete_samples = discrete_samples.reshape(
num_samples * batch_dim, num_variables
)

# repeat the batch of embedded condition to match number of choices.
condition_event_dim = embedded_condition.dim() - 1
@@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
f"{input_batch_dim} do not match."
)

cont_input, disc_input = _separate_input(input)
num_discrete_variables = self.discrete_net.net.num_variables
cont_input, disc_input = _separate_input(input, num_discrete_variables)
# Embed continuous condition
embedded_condition = self.condition_embedding(condition)
# expand and repeat to match batch of inputs.
@@ -204,3 +207,8 @@ def _separate_input(
Assumes the discrete data to live in the last columns of input.
"""
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]


def _is_discrete(input: Tensor) -> Tensor:
"""Infer discrete columns in input data."""
return torch.tensor([torch.allclose(col, col.round()) for col in input.T])
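A small illustration of the two helpers, under the assumption that discrete columns sit in the last positions of the input (as `_separate_input`'s docstring states):

```python
import torch

# Mixed batch: two continuous columns followed by one discrete column.
x = torch.tensor([
    [0.3, -1.2, 2.0],
    [1.7, 0.5, 0.0],
])

print(_is_discrete(x))  # tensor([False, False,  True])

cont, disc = _separate_input(x, 1)
print(cont.shape, disc.shape)  # torch.Size([2, 2]) torch.Size([2, 1])
```

Note that `_is_discrete` is a heuristic: a genuinely continuous column whose sampled values all happen to be integral would be flagged as discrete.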
42 changes: 28 additions & 14 deletions sbi/neural_nets/net_builders/categorial.py
@@ -1,16 +1,18 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import Optional

from torch import Tensor, nn, unique
from torch import Tensor, nn, tensor, unique

from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
from sbi.neural_nets.estimators import (
CategoricalMADE,
CategoricalMassEstimator,
)
from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device


@@ -21,6 +23,7 @@ def build_categoricalmassestimator(
z_score_y: Optional[str] = "independent",
num_hidden: int = 20,
num_layers: int = 2,
num_categories: Optional[Tensor] = None,
embedding_net: nn.Module = nn.Identity(),
):
"""Returns a density estimator for a categorical random variable.
@@ -33,28 +36,39 @@
num_hidden: Number of hidden units per layer.
num_layers: Number of hidden layers.
embedding_net: Embedding net for y.
num_categories: Number of categories for each variable.
"""

if z_score_x != "none":
raise ValueError("Categorical input should not be z-scored.")
if num_categories is None:
warnings.warn(
"Inferring num_categories from batch_x. Ensure all categories are present.",
stacklevel=2,
)

check_data_device(batch_x, batch_y)
if batch_x.shape[1] > 1:
raise NotImplementedError("CategoricalMassEstimator only supports 1D input.")
num_categories = unique(batch_x).numel()
dim_condition = get_numel(batch_y, embedding_net=embedding_net)

z_score_y_bool, structured_y = z_score_parser(z_score_y)
y_numel = get_numel(batch_y, embedding_net=embedding_net)

if z_score_y_bool:
embedding_net = nn.Sequential(
standardizing_net(batch_y, structured_y), embedding_net
)

categorical_net = CategoricalNet(
num_input=dim_condition,
if num_categories is None:
batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
inferred_categories = tensor([
unique(col).numel() for col in batch_x_discrete.T
])
num_categories = inferred_categories

categorical_net = CategoricalMADE(
num_categories=num_categories,
num_hidden=num_hidden,
num_layers=num_layers,
hidden_features=num_hidden,
context_features=y_numel,
num_blocks=num_layers,
embedding_net=embedding_net,
)
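
As a usage sketch (assuming defaults for the collapsed tail of the builder, which presumably wraps the net in a `CategoricalMassEstimator`), passing `num_categories` explicitly avoids the inference warning:

```python
import torch

# Three observations of two discrete variables (3 and 2 categories),
# with 5-dimensional conditioning data.
batch_x = torch.tensor([[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]])
batch_y = torch.randn(3, 5)

estimator = build_categoricalmassestimator(
    batch_x,
    batch_y,
    num_categories=torch.tensor([3, 2]),  # explicit: nothing inferred from batch_x
)
```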
