Skip to content

Commit

Permalink
refactor: remove simulate_for_sbi from tests (#1208)
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb authored Jul 30, 2024
1 parent 81fffcf commit b275448
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 42 deletions.
39 changes: 18 additions & 21 deletions tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
MCMCPosterior,
RejectionPosterior,
posterior_estimator_based_potential,
simulate_for_sbi,
)
from sbi.neural_nets import posterior_nn
from sbi.simulators.linear_gaussian import (
Expand Down Expand Up @@ -176,9 +175,8 @@ def simulator(theta):

inference = SNPE_C(prior, density_estimator=density_estimator)

theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=1000
)
theta = prior.sample((num_simulations,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(theta, x).train(
training_batch_size=100
)
Expand Down Expand Up @@ -315,14 +313,14 @@ def simulator(theta):

if method_str == "snpe_b":
inference = SNPE_B(**creation_args)
theta, x = simulate_for_sbi(simulator, prior, 500, simulation_batch_size=10)
theta = prior.sample((500,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(theta, x).train()
posterior1 = DirectPosterior(
prior=prior, posterior_estimator=posterior_estimator
).set_default_x(x_o)
theta, x = simulate_for_sbi(
simulator, posterior1, 1000, simulation_batch_size=10
)
theta = posterior1.sample((1000,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(
theta, x, proposal=posterior1
).train()
Expand All @@ -331,7 +329,8 @@ def simulator(theta):
).set_default_x(x_o)
elif method_str == "snpe_c":
inference = SNPE_C(**creation_args)
theta, x = simulate_for_sbi(simulator, prior, 900, simulation_batch_size=50)
theta = prior.sample((900,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(theta, x).train()
posterior1 = DirectPosterior(
prior=prior, posterior_estimator=posterior_estimator
Expand All @@ -348,17 +347,17 @@ def simulator(theta):
for r in range(num_rounds):
if r == 2:
final_round = True
theta, x = simulate_for_sbi(
simulator, proposal, 500, simulation_batch_size=50
)
theta = proposal.sample((500,))
x = simulator(theta)
inference = inference.append_simulations(theta, x, proposal=proposal)
_ = inference.train(max_num_epochs=200, final_round=final_round)
posterior = inference.build_posterior().set_default_x(x_o)
proposal = posterior
elif method_str.startswith("tsnpe"):
sample_method = "rejection" if method_str == "tsnpe_rejection" else "sir"
inference = SNPE_C(**creation_args)
theta, x = simulate_for_sbi(simulator, prior, 900, simulation_batch_size=50)
theta = prior.sample((900,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(theta, x).train()
posterior1 = DirectPosterior(
prior=prior, posterior_estimator=posterior_estimator
Expand Down Expand Up @@ -418,7 +417,8 @@ def simulator(theta):

inference = SNPE_C(prior, show_progress_bars=False)

theta, x = simulate_for_sbi(simulator, prior, 1000)
theta = prior.sample((1000,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(theta, x).train()
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator, prior, x_o
Expand Down Expand Up @@ -533,12 +533,8 @@ def simulator(theta):
inference = SNPE_C(prior, density_estimator=net, show_progress_bars=True)

# We need a pretty big dataset to properly model the bimodality.
theta, x = simulate_for_sbi(
simulator,
prior,
num_simulations,
simulation_batch_size=num_simulations,
)
theta = prior.sample((num_simulations,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(theta, x).train(
training_batch_size=1000, max_num_epochs=60
)
Expand Down Expand Up @@ -721,7 +717,8 @@ def simulator(theta):
inference = SNPE_C(prior, density_estimator="mdn")

for _ in range(3):
theta, x = simulate_for_sbi(simulator, proposal, 200)
theta = proposal.sample((200,))
x = simulator(theta)
_ = inference.append_simulations(theta, x, proposal=proposal).train()
posterior = inference.build_posterior().set_default_x(x_o)
proposal = posterior
21 changes: 8 additions & 13 deletions tests/linearGaussian_snre_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
RejectionPosterior,
VIPosterior,
ratio_estimator_based_potential,
simulate_for_sbi,
)
from sbi.inference.snre.snre_base import RatioEstimator
from sbi.simulators.linear_gaussian import (
Expand Down Expand Up @@ -185,9 +184,8 @@ def simulator(theta):
inference = snre_method(**kwargs)

# Should use default `num_atoms=10` for SRE; `num_atoms=2` for AALR
theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=num_simulations
)
theta = prior.sample((num_simulations,))
x = simulator(theta)
ratio_estimator = inference.append_simulations(theta, x).train(**train_kwargs)
potential_fn, theta_transform = ratio_estimator_based_potential(
ratio_estimator=ratio_estimator, prior=prior, x_o=x_o
Expand Down Expand Up @@ -280,9 +278,8 @@ def simulator(theta):

inference = snre_method(show_progress_bars=False)

theta, x = simulate_for_sbi(
simulator, prior, num_simulations_per_round, simulation_batch_size=50
)
theta = prior.sample((num_simulations_per_round,))
x = simulator(theta)
ratio_estimator = inference.append_simulations(theta, x).train()
potential_fn, theta_transform = ratio_estimator_based_potential(
prior=prior, ratio_estimator=ratio_estimator, x_o=x_o
Expand All @@ -293,9 +290,8 @@ def simulator(theta):
)
posterior1.train()

theta, x = simulate_for_sbi(
simulator, posterior1, num_simulations_per_round, simulation_batch_size=50
)
theta = posterior1.sample((num_simulations_per_round,))
x = simulator(theta)
ratio_estimator = inference.append_simulations(theta, x).train()
potential_fn, theta_transform = ratio_estimator_based_potential(
prior=prior, ratio_estimator=ratio_estimator, x_o=x_o
Expand Down Expand Up @@ -377,9 +373,8 @@ def test_api_sre_sampling_methods(

inference = SNRE_B(classifier="resnet", show_progress_bars=False)

theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=num_simulations
)
theta = prior.sample((num_simulations,))
x = simulator(theta)
ratio_estimator = inference.append_simulations(theta, x).train(max_num_epochs=5)
potential_fn, theta_transform = ratio_estimator_based_potential(
ratio_estimator=ratio_estimator, prior=prior, x_o=x_o
Expand Down
11 changes: 5 additions & 6 deletions tests/sbc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from torch.distributions import MultivariateNormal, Uniform

from sbi.diagnostics import check_sbc, get_nltp, run_sbc
from sbi.inference import SNLE, SNPE, simulate_for_sbi
from sbi.inference import SNLE, SNPE
from sbi.simulators import linear_gaussian
from sbi.utils import BoxUniform, MultipleIndependent
from sbi.utils.user_input_checks import process_prior, process_simulator
from tests.test_utils import PosteriorPotential, TractablePosterior


Expand Down Expand Up @@ -51,9 +50,8 @@ def simulator(theta):

inferer = method(prior, show_progress_bars=False, density_estimator=model)

prior, _, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator, prior, prior_returns_numpy)
theta, x = simulate_for_sbi(simulator, prior, num_simulations)
theta = prior.sample((num_simulations,))
x = simulator(theta)

_ = inferer.append_simulations(theta, x).train(
training_batch_size=100, max_num_epochs=max_num_epochs
Expand Down Expand Up @@ -106,7 +104,8 @@ def simulator(theta):

inferer = method(prior, show_progress_bars=False, density_estimator=model)

theta, x = simulate_for_sbi(simulator, prior, num_simulations)
theta = prior.sample((num_simulations,))
x = simulator(theta)

_ = inferer.append_simulations(theta, x).train(
training_batch_size=100, max_num_epochs=max_num_epochs
Expand Down
5 changes: 3 additions & 2 deletions tests/tarp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
check_tarp,
run_tarp,
)
from sbi.inference import SNPE, simulate_for_sbi
from sbi.inference import SNPE
from sbi.simulators import linear_gaussian
from sbi.utils import BoxUniform
from sbi.utils.metrics import l1, l2
Expand Down Expand Up @@ -312,7 +312,8 @@ def simulator(theta):

inferer = method(prior, show_progress_bars=False, density_estimator=model)

theta, x = simulate_for_sbi(simulator, prior, num_simulations)
theta = prior.sample((num_simulations,))
x = simulator(theta)

_ = inferer.append_simulations(theta, x).train(
training_batch_size=100, max_num_epochs=max_num_epochs
Expand Down

0 comments on commit b275448

Please sign in to comment.