Skip to content

Commit

Permalink
fix: posterior potential iid handling (#1276)
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb authored Sep 12, 2024
1 parent 25dd902 commit 9e5bed4
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions sbi/inference/potentials/posterior_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ def set_x(self, x_o: Optional[Tensor], x_is_iid: Optional[bool] = False):
For posterior-based methods, `x_o` is not allowed to be iid, as we assume that
iid `x` is handled by a Permutation Invariant embedding net.
"""
if x_is_iid:
if x_is_iid and x_o is not None and x_o.shape[0] > 1:
raise NotImplementedError(
"For NPE, iid `x` must be handled by a Permutation Invariant embedding \
net. Therefore, the iid dimension of `x` is added to the event\
dimension of `x`. Please set `x_is_iid=False`."
"For NPE, iid `x` must be handled by a permutation invariant embedding "
"net. Therefore, the iid dimension of `x` is added to the event "
"dimension of `x`. Please set `x_is_iid=False`."
)
else:
super().set_x(x_o, x_is_iid=False)
Expand Down
4 changes: 2 additions & 2 deletions tests/linearGaussian_fmpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ def simulator(theta):
# Evaluate the conditional density be drawing samples and smoothing with a Gaussian
# kde.
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator, prior=prior, x_o=x_o
)
posterior_estimator, prior=prior
).set_x(x_o, x_is_iid=False)
(
conditioned_potential_fn,
restricted_tf,
Expand Down
4 changes: 2 additions & 2 deletions tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ def simulator(theta):
# Evaluate the conditional density be drawing samples and smoothing with a Gaussian
# kde.
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator, prior=prior, x_o=x_o
)
posterior_estimator, prior=prior
).set_x(x_o, x_is_iid=False)
(
conditioned_potential_fn,
restricted_tf,
Expand Down

0 comments on commit 9e5bed4

Please sign in to comment.