Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add method for iid-batched conditioning. #1331

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

janfb
Copy link
Contributor

@janfb janfb commented Dec 12, 2024

  • introduce condition_on method for likelihood-estimator-based potential to condition on a batch of theta values (experimental conditions) that match the current batch of iid x_o.
  • deprecate MNLE-based potential (can be nle-based)
  • adapt tests for conditioned mnle.

- deprecate MNLE-based potential (can be nle-based)
- adapt tests for conditioned mnle.
Copy link

codecov bot commented Dec 12, 2024

Codecov Report

Attention: Patch coverage is 26.08696% with 17 lines in your changes missing coverage. Please review.

Project coverage is 78.30%. Comparing base (06890eb) to head (074efa0).

Files with missing lines Patch % Lines
...inference/potentials/likelihood_based_potential.py 19.04% 17 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1331       +/-   ##
===========================================
- Coverage   89.39%   78.30%   -11.09%     
===========================================
  Files         118      118               
  Lines        8709     8736       +27     
===========================================
- Hits         7785     6841      -944     
- Misses        924     1895      +971     
Flag Coverage Δ
unittests 78.30% <26.08%> (-11.09%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/inference/trainers/nle/mnle.py 85.36% <100.00%> (-7.32%) ⬇️
sbi/utils/conditional_density_utils.py 73.46% <100.00%> (-21.09%) ⬇️
sbi/utils/sbiutils.py 78.11% <ø> (-8.68%) ⬇️
...inference/potentials/likelihood_based_potential.py 60.00% <19.04%> (-40.00%) ⬇️

... and 29 files with indirect coverage changes

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I fully understand tbh.

IIUC, then there are four different dimensions across which we could vectorize:

  1. Across the x batch-dimension, in case we want to amortize.
  2. Across the x iid-dimension, in case we have iid samples.
  3. Across a batch of theta conditions which act as different experimental conditions
  4. Across thetas, in case we want to run multi-chain.

Which one of these does this PR solve? And which ones does it not yet solve?

Can we maybe somewhere define names for each of these four dimensions above and specify which ones are handled by which function (or which ones are currently assumed to be the same dimension and therefore do not have a cartesian product applied to them)?

Args:
x: Batch of iid data of shape `(iid_dim, *event_shape)`.
theta_without_condition: Batch of parameters `(batch_dim, *event_shape)`
condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition is in theta space, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iid_dim is the same as for x, right?


Args:
x: Batch of iid data of shape `(iid_dim, *event_shape)`.
theta_without_condition: Batch of parameters `(batch_dim, *event_shape)`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be useful to define event_shape more explicitly here.

likelihoods.

Args:
x: Batch of iid data of shape `(iid_dim, *event_shape)`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename event_shape to x_event_shape.


Returns:
log_likelihood_trial_sum: log likelihood for each parameter, summed over all
batch entries (iid trials) in `x`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please specify shape.

x_o.

Args:
condition: The condition to fix.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the condition be the full theta or only the ones that are not part of dims_to_sample?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sentence above implies that it is only "a part of theta"

"""Returns a potential conditioned on a subset of theta dimensions.

The condition is a part of theta, but is assumed to correspond to a batch of iid
x_o.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the batch_shape of condition must match the batch_shape of x later on?

condition.repeat_interleave(num_theta, dim=0), # repeat AABB
],
dim=-1,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is nasty, but I also don't think that there is another way to do this.

@janfb
Copy link
Contributor Author

janfb commented Dec 13, 2024

As discussed in the call, the plan is to

  • introduce (sample, batch, *event) for the theta_condition at inference time
  • this will enable sampling for multiple xs (batched x) multiple iid observations (for each x) and simultaneously passing a corresponding batch of conditions for each x. For example, for a decision-making use-case, it will be possible to sample in one call: multiple subject (batch of xs) where each subject performed N trials with N (different) experimental conditions. N needs to be the same for all subjects though.
  • improve naming and docs to avoid confusion between condition and experimental_condition, e.g., use theta_condition
  • add tests for new conditioning functions

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants