fix: verify ND training is working with CatMADE.
jnsbck committed Sep 13, 2024

1 parent 089d1d3 commit 045bf5e
Showing 4 changed files with 68 additions and 44 deletions.
44 changes: 23 additions & 21 deletions sbi/neural_nets/estimators/categorical_net.py
@@ -4,21 +4,20 @@
from typing import Optional

import torch
from torch import Tensor, nn, distributions
from nflows.nn.nde.made import MADE
from nflows.utils import torchutils
from torch import Tensor, nn
from torch.distributions import Categorical
from torch.nn import Sigmoid, Softmax
from torch.nn import functional as F

from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from nflows.nn.nde.made import MADE
from torch.nn import functional as F
from nflows.utils import torchutils
import numpy as np


class CategoricalMADE(MADE):
def __init__(
self,
categories, # List[int] or Tensor[int]
categories, # Tensor[int]
hidden_features,
context_features=None,
num_blocks=2,
@@ -31,7 +30,6 @@ def __init__(
custom_initialization=True,
embedding_net: Optional[nn.Module] = nn.Identity(),
):

if use_residual_blocks and random_mask:
raise ValueError("Residual blocks can't be used with random masks.")

@@ -65,54 +63,57 @@ def __init__(
def forward(self, inputs, context=None):
embedded_inputs = self.embedding_net.forward(inputs)
return super().forward(embedded_inputs, context=context)

def compute_probs(self, outputs):
ps = F.softmax(outputs, dim=-1)*self.mask
ps = F.softmax(outputs, dim=-1) * self.mask
ps = ps / ps.sum(dim=-1, keepdim=True)
return ps.squeeze(-2)
return ps
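A toy illustration (not part of the commit, and assuming `self.mask` zeroes the padded category slots of variables with fewer than the maximum number of categories): masking the softmax output and renormalizing keeps each variable's probabilities restricted to its own categories.

import torch
import torch.nn.functional as F

# (batch, num_variables, max_num_categories); the second variable has only
# 2 categories, so its third slot is masked out and the row is renormalized.
outputs = torch.randn(1, 2, 3)
mask = torch.tensor([[1.0, 1.0, 1.0],
                     [1.0, 1.0, 0.0]])

ps = F.softmax(outputs, dim=-1) * mask
ps = ps / ps.sum(dim=-1, keepdim=True)
print(ps.sum(dim=-1))  # tensor([[1., 1.]]) -- valid per-variable distributions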

# outputs (batch_size, num_variables, num_categories)
def log_prob(self, inputs, context=None):
outputs = self.forward(inputs, context=context)
outputs = outputs.reshape(*inputs.shape, self.num_categories)
ps = self.compute_probs(outputs)

# categorical log prob
log_prob = torch.log(ps.gather(-1, inputs.long()))
log_prob = log_prob.sum(dim=-1)
log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long()))
log_prob = log_prob.squeeze(-1).sum(dim=-1)

return log_prob
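A toy check (not part of the commit) of the adjusted indexing: `inputs` carries one category index per variable, so it is unsqueezed to line up with the trailing category dimension of `ps` before `gather`, and the per-variable log probabilities are then summed into a joint log probability.

import torch

ps = torch.tensor([[[0.2, 0.8],
                    [0.5, 0.5]]])      # (batch=1, num_variables=2, num_categories=2)
inputs = torch.tensor([[1.0, 0.0]])    # observed category index per variable

log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long()))  # (1, 2, 1)
log_prob = log_prob.squeeze(-1).sum(dim=-1)
print(log_prob)  # log(0.8) + log(0.5) ≈ -0.9163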

def sample(self, sample_shape, context=None):
# Ensure sample_shape is a tuple
if isinstance(sample_shape, int):
sample_shape = (sample_shape,)
sample_shape = torch.Size(sample_shape)

# Calculate total number of samples
num_samples = torch.prod(torch.tensor(sample_shape)).item()

# Prepare context
if context is not None:
if context.ndim == 1:
context = context.unsqueeze(0)
context = torchutils.repeat_rows(context, num_samples)

else:
context = torch.zeros(num_samples, self.context_dim)

with torch.no_grad():
samples = torch.zeros(num_samples, self.num_variables)
for variable in range(self.num_variables):
outputs = self.forward(samples, context)
outputs = outputs.reshape(num_samples, self.num_variables, self.num_categories)
outputs = outputs.reshape(
num_samples, self.num_variables, self.num_categories
)
ps = self.compute_probs(outputs)
samples[:, variable] = Categorical(probs=ps[:,variable]).sample()
samples[:, variable] = Categorical(probs=ps[:, variable]).sample()

return samples.reshape(*sample_shape, self.num_variables)

def _initialize(self):
pass
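A minimal usage sketch (not part of the commit; hidden sizes, data shapes, and the single-context call pattern are illustrative assumptions) for the class as changed above, with `CategoricalMADE` imported as in the builder below:

import torch

from sbi.neural_nets.estimators import CategoricalMADE

made = CategoricalMADE(
    categories=torch.tensor([3, 2]),  # category counts per discrete variable
    hidden_features=20,
    context_features=4,
    num_blocks=2,
)
context = torch.randn(4)                        # a single conditioning vector
samples = made.sample((10,), context=context)   # -> (10, 2)
log_probs = made.log_prob(samples, context=context.repeat(10, 1))  # -> (10,)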


class CategoricalNet(nn.Module):
"""Conditional density (mass) estimation for a categorical random variable.
@@ -145,6 +146,7 @@ def __init__(
self.activation = Sigmoid()
self.softmax = Softmax(dim=1)
self.num_categories = num_categories
self.num_variables = 1

# Maybe add embedding net in front.
if embedding_net is not None:
14 changes: 11 additions & 3 deletions sbi/neural_nets/estimators/mixed_density_estimator.py
@@ -80,8 +80,10 @@ def sample(
sample_shape=sample_shape,
condition=condition,
)
# Trailing `1` because `Categorical` has event_shape `()`.
discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
num_variables = self.discrete_net.net.num_variables
discrete_samples = discrete_samples.reshape(
num_samples * batch_dim, num_variables
)

# repeat the batch of embedded condition to match number of choices.
condition_event_dim = embedded_condition.dim() - 1
@@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
f"{input_batch_dim} do not match."
)

cont_input, disc_input = _separate_input(input)
num_disc = self.discrete_net.net.num_variables
cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc)
# Embed continuous condition
embedded_condition = self.condition_embedding(condition)
# expand and repeat to match batch of inputs.
@@ -204,3 +207,8 @@ def _separate_input(
Assumes the discrete data to live in the last columns of input.
"""
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]


def _is_discrete(input: Tensor) -> Tensor:
"""Infer discrete columns in input data."""
return torch.tensor([torch.allclose(col, col.round()) for col in input.T])
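A toy check (not part of the commit) of the new helper: a column counts as discrete when all of its entries are (approximately) integer-valued, e.g. choice columns appended after reaction times.

import torch

from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete

batch_x = torch.tensor([[0.73, 1.0, 2.0],
                        [1.21, 0.0, 1.0]])  # 1 continuous column, 2 discrete columns
print(_is_discrete(batch_x))                # tensor([False,  True,  True])
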
31 changes: 20 additions & 11 deletions sbi/neural_nets/net_builders/categorial.py
@@ -1,16 +1,19 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import Optional

from torch import Tensor, nn, unique, tensor
from torch import Tensor, nn, tensor, unique

from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet, CategoricalMADE
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
from sbi.neural_nets.estimators import (
CategoricalMADE,
CategoricalMassEstimator,
CategoricalNet,
)
from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device


@@ -70,7 +73,7 @@ def build_autoregressive_categoricalmassestimator(
z_score_y: Optional[str] = "independent",
num_hidden: int = 20,
num_layers: int = 2,
num_variables: int = 1,
categories: Optional[Tensor] = None,
embedding_net: nn.Module = nn.Identity(),
):
"""Returns a density estimator for a categorical random variable.
@@ -87,6 +90,11 @@ def build_autoregressive_categoricalmassestimator(

if z_score_x != "none":
raise ValueError("Categorical input should not be z-scored.")
if categories is None:
warnings.warn(
"Inferring categories from batch_x. Ensure all categories are present.",
stacklevel=2,
)

check_data_device(batch_x, batch_y)

@@ -98,17 +106,18 @@ def build_autoregressive_categoricalmassestimator(
standardizing_net(batch_y, structured_y), embedding_net
)

categories = tensor([unique(variable).numel() for variable in batch_x.T])
categories = categories[-num_variables:]
batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
inferred_categories = tensor([unique(col).numel() for col in batch_x_discrete.T])
categories = categories if categories is not None else inferred_categories

categorical_net = CategoricalMADE(
categories=categories,
hidden_features=num_hidden,
context_features=y_numel,
num_blocks=num_layers,
# TODO: embedding_net=embedding_net,
embedding_net=embedding_net,
)

return CategoricalMassEstimator(
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
)
)
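A hedged usage sketch (not part of the commit; the data shapes and the explicit `z_score_x="none"` argument are illustrative assumptions): passing `categories` explicitly sidesteps the new warning, since inferring category counts from `batch_x` only works if every category appears in the batch.

import torch

from sbi.neural_nets.net_builders.categorial import (
    build_autoregressive_categoricalmassestimator,
)

batch_x = torch.randint(0, 3, (100, 2)).float()  # two categorical variables, 3 categories each
batch_y = torch.randn(100, 4)                    # conditioning data

estimator = build_autoregressive_categoricalmassestimator(
    batch_x,
    batch_y,
    z_score_x="none",
    categories=torch.tensor([3, 3]),             # category counts per variable
)
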
23 changes: 14 additions & 9 deletions sbi/neural_nets/net_builders/mnle.py
@@ -8,8 +8,14 @@
from torch import Tensor, nn

from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import _separate_input
from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator, build_autoregressive_categoricalmassestimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
_is_discrete,
_separate_input,
)
from sbi.neural_nets.net_builders.categorial import (
build_autoregressive_categoricalmassestimator,
build_categoricalmassestimator,
)
from sbi.neural_nets.net_builders.flow import (
build_made,
build_maf,
@@ -26,10 +32,7 @@
build_zuko_unaf,
)
from sbi.neural_nets.net_builders.mdn import build_mdn
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
)
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device

model_builders = {
@@ -128,13 +131,14 @@ def build_mnle(

warnings.warn(
"The mixed neural likelihood estimator assumes that x contains "
"continuous data in the first n-1 columns (e.g., reaction times) and "
"categorical data in the last column (e.g., corresponding choices). If "
"continuous data in the first n-k columns (e.g., reaction times) and "
"categorical data in the last k columns (e.g., corresponding choices). If "
"this is not the case for the passed `x` do not use this function.",
stacklevel=2,
)
# Separate continuous and discrete data.
cont_x, disc_x = _separate_input(batch_x)
num_disc = int(torch.sum(_is_discrete(batch_x)))
cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)
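A toy illustration (not part of the commit) of the split above: the number of trailing discrete columns is now inferred with `_is_discrete` instead of being hard-coded to one.

import torch

from sbi.neural_nets.estimators.mixed_density_estimator import (
    _is_discrete,
    _separate_input,
)

batch_x = torch.tensor([[0.42, 1.0, 2.0]])        # 1 continuous + 2 discrete columns
num_disc = int(torch.sum(_is_discrete(batch_x)))  # -> 2
cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)
print(cont_x.shape, disc_x.shape)                 # torch.Size([1, 1]) torch.Size([1, 2])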

# Set up y-embedding net with z-scoring.
z_score_y_bool, structured_y = z_score_parser(z_score_y)
@@ -158,6 +162,7 @@ def build_mnle(
embedding_net=embedding_net,
)
elif categorical_model == "mlp":
assert num_disc == 1, "MLP only supports 1D input."
discrete_net = build_categoricalmassestimator(
disc_x,
batch_y,
