MNPE class similar to MNLE #1362
Conversation
Update:
- …ODO: MNPE class + test
- …dent; embedding net in mnpe not working yet
- …not allow gpu handling yet though
Updates:
looks good overall, except one central question about the call signature of MNPE.
This estimator combines a Categorical net and a neural density estimator to model
data with mixed types (discrete and continuous), e.g., as they occur in
decision-making models. It can be used for both likelihood and posterior estimation
of mixed data.
Suggested change:
- This estimator combines a Categorical net and a neural density estimator to model
- data with mixed types (discrete and continuous), e.g., as they occur in
- decision-making models. It can be used for both likelihood and posterior estimation
- of mixed data.
+ This estimator combines a categorical mass estimator and a density estimator to model
+ variables with mixed types (discrete and continuous). It can be used for both likelihood
+ estimation (e.g., for discrete decisions and continuous reaction times in decision-making
+ models) and posterior estimation (e.g., for models that have both discrete and continuous
+ parameters).
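For context, a minimal usage sketch of the posterior case this docstring describes (assuming the MNPE interface mirrors the existing MNLE one, as this PR proposes; prior, theta, x, and x_o are placeholders):

from sbi.inference import MNPE

trainer = MNPE(prior=prior)
trainer.append_simulations(theta, x).train()
posterior = trainer.build_posterior()
# Draw mixed-type (discrete + continuous) parameter samples given x_o.
samples = posterior.sample((1000,), x=x_o)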
"""The forward method is not implemented for MNLE, use '.sample(...)' to | ||
generate samples though a forward pass.""" | ||
"""The forward method is not implemented for mixed neural density | ||
estimation,use '.sample(...)' to generate samples though a forward |
Suggested change:
- estimation,use '.sample(...)' to generate samples though a forward
+ estimation, use '.sample(...)' to generate samples through a forward
def build_mnle(
    batch_x: Tensor,
    batch_y: Tensor,
    **kwargs,
) -> MixedDensityEstimator:
    """Returns a mixed neural likelihood estimator.

    This estimator models p(x|theta), where x contains both continuous and discrete data.

    Args:
        batch_x: Batch of xs (data), used to infer dimensionality.
        batch_y: Batch of ys (parameters), used to infer dimensionality.
        **kwargs: Additional arguments passed to _build_mixed_density_estimator.

    Returns:
        MixedDensityEstimator for MNLE.
    """
    return _build_mixed_density_estimator(
        batch_x=batch_x, batch_y=batch_y, mode="mnle", **kwargs
    )


def build_mnpe(
    batch_x: Tensor,
    batch_y: Tensor,
    **kwargs,
) -> MixedDensityEstimator:
    """Returns a mixed neural posterior estimator.

    This estimator models p(theta|x), where x contains both continuous and discrete data.
I am confused by these two call functions. Maybe I am missing something, but couldn't we just call _build_mixed_density_estimator with batch_x and batch_y swapped for MNPE and MNLE? To me it seems this swapping is not happening, i.e., we need to make sure that in MNPE we are only embedding x and not theta, and v.v. in MNLE. Let's discuss tomorrow.
True, we could remove both functions and just use _build_mixed_density_estimator. The swapping of x/theta is happening because the function is once called as likelihood_nn and once as posterior_nn.
Update:
some more comments on the tests and the refactoring. I am suggesting a toy example with ground-truth posterior for the MNPE scenario to test the accuracy (see the sketch below).
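One possible such toy problem (hypothetical names and constants, not part of this PR): a Gaussian continuous parameter plus a Bernoulli discrete parameter that shifts the mean, so the posterior over the discrete parameter is available in closed form:

import torch
from torch.distributions import Normal

# prior: theta_c ~ N(0, 1), theta_d ~ Bernoulli(0.5)
# simulator: x | theta ~ N(theta_c + 2 * theta_d, 0.5**2)

def true_discrete_posterior(x_o: torch.Tensor) -> torch.Tensor:
    """p(theta_d = 1 | x_o), with theta_c marginalized out analytically."""
    # Marginal likelihood: x | theta_d ~ N(2 * theta_d, 1**2 + 0.5**2).
    std = (1.0 + 0.25) ** 0.5
    p1 = 0.5 * Normal(2.0, std).log_prob(x_o).exp()
    p0 = 0.5 * Normal(0.0, std).log_prob(x_o).exp()
    return p1 / (p0 + p1)

The MNPE posterior marginal over the discrete parameter could then be compared against this reference.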
sbi/utils/sbiutils.py
Outdated
try:
    prior_mean = prior.mean.to(device)
    prior_std = prior.stddev.to(device)
except (NotImplementedError, AttributeError):
    warnings.warn(
        "The passed discrete prior has no mean or stddev attribute, "
        "estimating them from samples to build affine standardizing "
        "transform.",
        stacklevel=2,
    )
    theta = prior.sample(torch.Size((num_prior_samples_for_zscoring,)))
    prior_mean = theta.mean(dim=0).to(device)
    prior_std = theta.std(dim=0).to(device)
move this into a small function to avoid code duplication?
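One way to factor it out (a sketch with a hypothetical name; the duplicated snippet above then reduces to a single call):

import warnings
from typing import Tuple

import torch
from torch import Tensor
from torch.distributions import Distribution

def get_prior_mean_and_std(
    prior: Distribution, device: str, num_samples: int = 1000
) -> Tuple[Tensor, Tensor]:
    """Return the prior mean and std, estimating from samples if needed."""
    try:
        return prior.mean.to(device), prior.stddev.to(device)
    except (NotImplementedError, AttributeError):
        warnings.warn(
            "The passed prior has no mean or stddev attribute, estimating "
            "them from samples to build the affine standardizing transform.",
            stacklevel=2,
        )
        theta = prior.sample(torch.Size((num_samples,)))
        return theta.mean(dim=0).to(device), theta.std(dim=0).to(device)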
tests/mnpe_test.py
Outdated
x = mixed_param_simulator(theta)

# Build estimator manually
theta_embedding = FCEmbedding(1, 1)  # simple embedding net, 1 continuous parameter
this should be an x_embedding to avoid confusion.
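For instance (the same call as in the test above, just renamed, with the comment adjusted; this assumes x is one-dimensional here):

x_embedding = FCEmbedding(1, 1)  # simple embedding net for the 1D continuous x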
tests/mnpe_test.py
Outdated
    log_transform_x=False,
)
trainer = MNPE(density_estimator=density_estimator)
trainer.append_simulations(theta, x).train(max_num_epochs=5)
max_num_epochs=1 to speed up tests.
please remove all diffs here. We probably want to have some kind of tutorial or how-to guide with MNPE, but let's wait for the new documentation setup.
Update:
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1362 +/- ##
===========================================
- Coverage 89.62% 79.09% -10.53%
===========================================
Files 121 122 +1
Lines 9347 9394 +47
===========================================
- Hits 8377 7430 -947
- Misses 970 1964 +994
great to have the test now! 🎉
added some comments for the mnle defaults and some suggestions for the test.
looks good!
two final comments.
tests/mnpe_test.py
Outdated
theta_true = torch.cat(
    (
        torch.rand(batch_size, 1),
        torch.ones(batch_size, 1),
        torch.rand(batch_size, 2),
        torch.bernoulli(0.8 * torch.ones(batch_size, 1)),
    ),
    dim=1,
)
you could just do prior.sample((1,)), no?
we usually call it theta_o :)
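Combining both suggestions, the test setup would shrink to (a sketch; prior is the mixed prior defined in the test):

theta_o = prior.sample((1,))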
Two final final comments. Good to be merged afterwards. Really cool to have this implemented now 🔥 🎉
Implementation of mixed NPE where we have some continuous parameters theta followed by one (or, with PR #1269, multiple) discrete parameters. The observation space is fully continuous.
Deprecated mnle.py in net_builders and unified mnle/mnpe as mixed_nets.py.
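For illustration, the parameter layout this implies (a hypothetical sketch with placeholder shapes, not code from the PR):

import torch

theta_continuous = torch.randn(10, 2)                      # e.g., 2 continuous parameters
theta_discrete = torch.bernoulli(0.5 * torch.ones(10, 1))  # e.g., 1 discrete parameter
# Continuous parameters come first, discrete parameter(s) last:
theta = torch.cat([theta_continuous, theta_discrete], dim=1)  # shape (10, 3)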