Skip to content

Commit

Permalink
refactor: remove deprecated x_shape where not needed. (#1271)
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb authored Sep 11, 2024
1 parent dbcd06b commit 299854e
Show file tree
Hide file tree
Showing 5 changed files with 1 addition and 10 deletions.
1 change: 0 additions & 1 deletion sbi/inference/posteriors/ensemble_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def __init__(
potential_fn=potential_fn,
theta_transform=theta_transform,
device=device,
x_shape=None,
)

def ensure_same_device(self, posteriors: List) -> str:
Expand Down
2 changes: 0 additions & 2 deletions sbi/inference/trainers/fmpe/fmpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
npe_msg_on_invalid_x,
validate_theta_and_x,
warn_if_zscoring_changes_data,
x_shape_from_simulation,
)
from sbi.utils.sbiutils import mask_sims_from_prior

Expand Down Expand Up @@ -199,7 +198,6 @@ def train(
theta[self.train_indices].to("cpu"),
x[self.train_indices].to("cpu"),
)
self._x_shape = x_shape_from_simulation(x.to("cpu"))

del theta, x

Expand Down
2 changes: 0 additions & 2 deletions sbi/inference/trainers/npse/npse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
test_posterior_net_for_multi_d_x,
validate_theta_and_x,
warn_if_zscoring_changes_data,
x_shape_from_simulation,
)
from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior

Expand Down Expand Up @@ -282,7 +281,6 @@ def default_calibration_kernel(x):
theta[self.train_indices].to("cpu"),
x[self.train_indices].to("cpu"),
)
self._x_shape = x_shape_from_simulation(x.to("cpu"))

test_posterior_net_for_multi_d_x(
self._neural_net,
Expand Down
2 changes: 0 additions & 2 deletions sbi/inference/trainers/nre/nre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
check_estimator_arg,
check_prior,
clamp_and_warn,
x_shape_from_simulation,
)
from sbi.utils.torchutils import repeat_rows

Expand Down Expand Up @@ -203,7 +202,6 @@ def train(
theta[self.train_indices].to("cpu"),
x[self.train_indices].to("cpu"),
)
self._x_shape = x_shape_from_simulation(x.to("cpu"))
del x, theta
self._neural_net.to(self._device)

Expand Down
4 changes: 1 addition & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def __init__(
potential_fn: Callable,
theta_transform: Optional[TorchTransform] = None,
device: Optional[str] = "cpu",
x_shape: Optional[torch.Size] = None,
):
"""
Args:
Expand All @@ -228,10 +227,9 @@ def __init__(
Allows to perform, e.g. MCMC in unconstrained space.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
x_shape: Shape of the observed data.
"""
assert isinstance(potential_fn, PosteriorPotential)
super().__init__(potential_fn, theta_transform, device, x_shape)
super().__init__(potential_fn, theta_transform, device)

def sample(
self,
Expand Down

0 comments on commit 299854e

Please sign in to comment.