
MNPE class similar to MNLE #1362

Merged Mar 20, 2025 (21 commits)

Conversation

@dgedon (Collaborator) commented Jan 10, 2025

Implementation of mixed NPE, where we have some continuous parameters theta followed by one (or, with PR #1269, multiple) discrete parameters. The observation space is fully continuous.

Deprecated mnle.py in net_builders and unified MNLE/MNPE as mixed_nets.py.
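
For orientation, a minimal usage sketch of the intended API, pieced together from the test snippets quoted later in this thread. The simulator, its dimensions, and the exact signatures of MNPE and MultipleIndependent are assumptions, not the PR's verbatim code:

    import torch
    from torch.distributions import Bernoulli, Normal

    from sbi.inference import MNPE
    from sbi.utils import MultipleIndependent

    # Toy setup: one continuous parameter followed by one discrete (Bernoulli)
    # parameter; the observation space is fully continuous.
    prior = MultipleIndependent(
        [Normal(torch.zeros(1), torch.ones(1)), Bernoulli(probs=0.5 * torch.ones(1))],
        validate_args=False,
    )

    def simulator(theta: torch.Tensor) -> torch.Tensor:
        # Continuous observation that depends on both parameter types.
        return theta[:, :1] + theta[:, 1:] + 0.1 * torch.randn(theta.shape[0], 1)

    theta = prior.sample(torch.Size((1000,)))
    x = simulator(theta)

    trainer = MNPE(prior=prior)
    trainer.append_simulations(theta, x).train()
    posterior = trainer.build_posterior()
    samples = posterior.sample((1000,), x=x[:1])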

@dgedon (Collaborator, Author) commented Jan 17, 2025

Update:

  • MNPE and tests are implemented.
  • For the test with a Bernoulli prior, I had to change mcmc_transforms to handle discrete distributions. By default, we just compute mean/std from samples for discrete distributions (see the sketch after this list).
  • Currently, MNPE with embedding nets does not work yet; it raises a backward in-place operation error that I couldn't solve yet.
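
A minimal sketch of that sample-based fallback (illustrative only; the actual diff is quoted later in this thread):

    import torch
    from torch.distributions import Bernoulli

    # Discrete prior for which a bijective transform to unbounded space is not
    # available; build an affine standardizing transform from samples instead.
    prior = Bernoulli(probs=0.8 * torch.ones(1))

    samples = prior.sample(torch.Size((1000,)))
    prior_mean = samples.mean(dim=0)  # close to 0.8
    prior_std = samples.std(dim=0)    # close to 0.4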

@janfb (Contributor) commented Feb 25, 2025

@dgedon #1269 is now merged 🙌

@dgedon (Collaborator, Author) commented Mar 18, 2025

Updates:

  • Bug fix so everything works now. Essentially, normalization needs to be handled with care when switching from MNLE to MNPE.
  • Removed unnecessary GPU handling (hackathon task). This limits MultipleIndependent, which does not allow a device argument yet.

@janfb (Contributor) left a review

Looks good overall, except for one central question about the call signature of MNPE.

janfb (Contributor) commented on lines 16 to 19:

    This estimator combines a Categorical net and a neural density estimator to model
    data with mixed types (discrete and continuous), e.g., as they occur in
    decision-making models. It can be used for both likelihood and posterior estimation
    of mixed data.

Suggested change:

    This estimator combines a categorical mass estimator and a density estimator to model
    variables with mixed types (discrete and continuous). It can be used for both likelihood
    estimation (e.g., for discrete decisions and continuous reaction times in decision-making
    models) or posterior estimation (e.g., for models that have both discrete and continuous
    parameters).

"""The forward method is not implemented for MNLE, use '.sample(...)' to
generate samples though a forward pass."""
"""The forward method is not implemented for mixed neural density
estimation,use '.sample(...)' to generate samples though a forward
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
estimation,use '.sample(...)' to generate samples though a forward
estimation, use '.sample(...)' to generate samples though a forward

janfb (Contributor) commented on lines 228 to 257:

    def build_mnle(
        batch_x: Tensor,
        batch_y: Tensor,
        **kwargs,
    ) -> MixedDensityEstimator:
        """Returns a mixed neural likelihood estimator.

        This estimator models p(x|theta) where x contains both continuous and
        discrete data.

        Args:
            batch_x: Batch of xs (data), used to infer dimensionality.
            batch_y: Batch of ys (parameters), used to infer dimensionality.
            **kwargs: Additional arguments passed to _build_mixed_density_estimator.

        Returns:
            MixedDensityEstimator for MNLE.
        """
        return _build_mixed_density_estimator(
            batch_x=batch_x, batch_y=batch_y, mode="mnle", **kwargs
        )

    def build_mnpe(
        batch_x: Tensor,
        batch_y: Tensor,
        **kwargs,
    ) -> MixedDensityEstimator:
        """Returns a mixed neural posterior estimator.

        This estimator models p(theta|x) where x contains both continuous and
        discrete data.

I am confused by these two call functions. Maybe I am missing something, but couldn't we just call _build_mixed_density_estimator with batch_x and batch_y swapped for MNPE and MNLE?

To me it seems this swapping is not happening, i.e., we need to make sure that in MNPE we only embed x and not theta, and vice versa in MNLE. Let's discuss tomorrow.

dgedon (Collaborator, Author) replied:

True, we could remove both functions and just use _build_mixed_density_estimator.

The swapping of x/theta happens because the function is called once via likelihood_nn and once via posterior_nn (see the sketch below).
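
A minimal sketch of that swap (hypothetical wiring; the actual factory code may differ):

    # The same mixed-density builder serves both directions.
    # MNLE models p(x | theta): the mixed-type variable is x.
    # MNPE models p(theta | x): the mixed-type variable is theta.
    def likelihood_nn_build_fn(batch_theta, batch_x, **kwargs):
        # x is modeled, theta is the conditioning (embedded) variable.
        return _build_mixed_density_estimator(
            batch_x=batch_x, batch_y=batch_theta, **kwargs
        )

    def posterior_nn_build_fn(batch_theta, batch_x, **kwargs):
        # theta is modeled, x is the conditioning (embedded) variable.
        return _build_mixed_density_estimator(
            batch_x=batch_theta, batch_y=batch_x, **kwargs
        )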

@dgedon (Collaborator, Author) commented Mar 19, 2025

Update:

  • simplify _build_mixed_density_estimator by not having mode='mnpe'/'mnle'
  • add a default log_transform_x kwarg to build_mnle and build_mnpe (see the sketch after this list)
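
A hedged sketch of what the resulting signatures might look like. The defaults shown here are assumptions; only log_transform_x=False for the MNPE test is confirmed later in this thread:

    # Assumed signatures, not the PR's verbatim code.
    def build_mnle(batch_x, batch_y, log_transform_x: bool = True, **kwargs):
        # MNLE: x often contains strictly positive continuous data (e.g.,
        # reaction times), so a log transform is a plausible default.
        ...

    def build_mnpe(batch_x, batch_y, log_transform_x: bool = False, **kwargs):
        # MNPE: the modeled variable is theta, where a log transform is not
        # generally appropriate.
        ...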

@janfb (Contributor) left a review

Some more comments on the tests and the refactoring.

I am suggesting a toy example with a ground-truth posterior for the MNPE scenario to test the accuracy.

janfb (Contributor) commented on lines 729 to 741:

    try:
        prior_mean = prior.mean.to(device)
        prior_std = prior.stddev.to(device)
    except (NotImplementedError, AttributeError):
        warnings.warn(
            "The passed discrete prior has no mean or stddev attribute, "
            "estimating them from samples to build affine standardizing "
            "transform.",
            stacklevel=2,
        )
        theta = prior.sample(torch.Size((num_prior_samples_for_zscoring,)))
        prior_mean = theta.mean(dim=0).to(device)
        prior_std = theta.std(dim=0).to(device)
Move this into a small function to avoid code duplication? (See the sketch below.)
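
One possible shape for such a helper (hypothetical name and signature; a sketch, not the PR's actual refactor):

    import warnings

    import torch
    from torch import Tensor

    def get_prior_mean_and_std(
        prior, device: str, num_samples_for_zscoring: int = 1000
    ) -> tuple[Tensor, Tensor]:
        """Return prior mean/std, estimated from samples when the prior does
        not expose mean/stddev (e.g., some discrete priors)."""
        try:
            prior_mean = prior.mean.to(device)
            prior_std = prior.stddev.to(device)
        except (NotImplementedError, AttributeError):
            warnings.warn(
                "The passed prior has no mean or stddev attribute, estimating "
                "them from samples to build affine standardizing transform.",
                stacklevel=2,
            )
            theta = prior.sample(torch.Size((num_samples_for_zscoring,)))
            prior_mean = theta.mean(dim=0).to(device)
            prior_std = theta.std(dim=0).to(device)
        return prior_mean, prior_std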

janfb (Contributor) commented:

    x = mixed_param_simulator(theta)

    # Build estimator manually
    theta_embedding = FCEmbedding(1, 1)  # simple embedding net, 1 continuous parameter

This should be an x_embedding to avoid confusion.

janfb (Contributor) commented:

        log_transform_x=False,
    )
    trainer = MNPE(density_estimator=density_estimator)
    trainer.append_simulations(theta, x).train(max_num_epochs=5)

Use max_num_epochs=1 to speed up tests.

janfb (Contributor) commented:

Please remove all diffs here. We probably want to have some kind of tutorial or how-to guide with MNPE, but let's wait for the new documentation setup.

@dgedon (Collaborator, Author) commented Mar 20, 2025

Update: added an accuracy test for MNPE consisting of two Gaussians with varying means. The observation comes from one of the two Gaussians, picked by a Bernoulli "selection" variable. The analytic reference posterior is compared to the MNPE estimate using C2ST.
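
For concreteness, a hedged sketch of an analytic reference posterior for a simplified variant of such a test (all prior choices and names below are assumptions, not the test's actual values): with mu_k ~ N(m_k, tau^2), z ~ Bernoulli(p), and x | mu, z ~ N(mu_z, sigma^2), the posterior over z is a reweighted Bernoulli, the selected mean has a conjugate Gaussian posterior, and the unselected mean keeps its prior.

    import torch
    from torch.distributions import Bernoulli, Normal

    # Assumed toy model, not the PR's exact test.
    m = torch.tensor([-1.0, 1.0])  # distinct prior means make z identifiable
    tau, sigma, p = 1.0, 0.5, 0.8

    def reference_posterior_sample(x_o: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Sample (mu_0, mu_1, z) from the analytic posterior given scalar x_o."""
        # p(z=k | x_o) is proportional to p(z=k) * N(x_o; m_k, tau^2 + sigma^2).
        marginal = Normal(m, (tau**2 + sigma**2) ** 0.5)
        log_w = marginal.log_prob(x_o) + torch.log(torch.tensor([1 - p, p]))
        z = Bernoulli(logits=log_w[1] - log_w[0]).sample((num_samples,))
        k = z.long()

        # Conjugate Gaussian posterior for the selected mean.
        post_var = 1.0 / (1.0 / tau**2 + 1.0 / sigma**2)
        post_mean = post_var * (m / tau**2 + x_o / sigma**2)

        mu = Normal(m, tau).sample((num_samples,))  # unselected mean keeps the prior
        mu[torch.arange(num_samples), k] = Normal(post_mean[k], post_var**0.5).sample()
        return torch.cat([mu, z.view(-1, 1)], dim=1)

    samples = reference_posterior_sample(torch.tensor(0.3), 1000)

Comparing such reference samples against MNPE posterior samples with C2ST is the pattern described in the update above.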

codecov bot commented Mar 20, 2025

Codecov Report

Attention: Patch coverage is 83.92857% with 9 lines in your changes missing coverage. Please review.

Project coverage is 79.09%. Comparing base (46ccee0) to head (bfaabe8).
Report is 12 commits behind head on main.

Files with missing lines Patch % Lines
sbi/utils/sbiutils.py 69.56% 7 Missing ⚠️
sbi/inference/trainers/npe/mnpe.py 95.23% 1 Missing ⚠️
sbi/utils/user_input_checks_utils.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1362       +/-   ##
===========================================
- Coverage   89.62%   79.09%   -10.53%     
===========================================
  Files         121      122        +1     
  Lines        9347     9394       +47     
===========================================
- Hits         8377     7430      -947     
- Misses        970     1964      +994     
Flag Coverage Δ
unittests 79.09% <83.92%> (-10.81%) ⬇️


Files with missing lines Coverage Δ
sbi/inference/__init__.py 100.00% <100.00%> (ø)
sbi/inference/trainers/npe/__init__.py 100.00% <100.00%> (ø)
.../neural_nets/estimators/mixed_density_estimator.py 94.73% <ø> (-1.76%) ⬇️
sbi/neural_nets/factory.py 90.90% <100.00%> (ø)
sbi/neural_nets/net_builders/__init__.py 100.00% <100.00%> (ø)
sbi/neural_nets/net_builders/mixed_nets.py 97.05% <100.00%> (ø)
sbi/inference/trainers/npe/mnpe.py 95.23% <95.23%> (ø)
sbi/utils/user_input_checks_utils.py 88.67% <50.00%> (-1.13%) ⬇️
sbi/utils/sbiutils.py 78.38% <69.56%> (-9.16%) ⬇️

... and 38 files with indirect coverage changes


@janfb (Contributor) left a review

Great to have the test now! 🎉

Added some comments on the MNLE defaults and some suggestions for the test.

@dgedon requested a review from janfb, March 20, 2025 11:10
@janfb (Contributor) left a review

Looks good! Two final comments.

janfb (Contributor) commented on lines 81 to 87:

    theta_true = torch.cat(
        (
            torch.rand(batch_size, 1),
            torch.ones(batch_size, 1),
            torch.rand(batch_size, 2),
            torch.bernoulli(0.8 * torch.ones(batch_size, 1)),
        ),
        dim=1,
    )
You could just do prior.sample((1,)), no?

We usually call it theta_o :) (both suggestions are combined in the sketch below).
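
A hypothetical replacement combining both suggestions (note that the original fixes the second entry to 1.0, so sampling everything from the prior is only equivalent if the test allows it):

    # Draw the reference parameter directly from the prior and name it theta_o.
    theta_o = prior.sample((1,))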

@dgedon requested review from janfb and removed the request for michaeldeistler, March 20, 2025 11:32
@janfb (Contributor) left a review

Two final final comments. Good to be merged afterwards. Really cool to have this implemented now 🔥 🎉

@dgedon merged commit 312e9ef into sbi-dev:main on Mar 20, 2025. 8 checks passed.