fix: joblib not saturating CPU during multiprocessing (#1188)
* refactoring simulate_for_sbi

* refactored simulate_for_sbi, introduced new wrapper wrap_as_joblib_efficient_simulator

* Finished refactoring simulate_for_sbi and wrap_as_joblib_efficient_simulator. The wrapping/casting currently increases the runtime roughly threefold, but the code must remain non-breaking for now.

* working on user_input_checks

* added temporary benchmark folder

* adding process_simulator / process_prior to tests

* added process_simulator and process_prior to some of the tests

* finished adding process_simulator and process_prior to default git tests (pytest -n auto -m "not slow and not gpu")

* Changes following PR 1188: removed `if-else` for `show_progress_bar` in `simulate_for_sbi`, improved comments, removed `benchmark` folder

* restructured simulation_batch_size logic in simulate_for_sbi according to #1188 discussion

* Bypassed `process_simulator` and `simulate_for_sbi` in tests/inference_on_device_test.py

#1188

Co-authored-by: Jan <[email protected]>

* adjusted imports and formatting

* changed `|` to `Union` in `inference/base.py::simulate_for_sbi`

---------

Co-authored-by: Janko Petkovic <[email protected]>
Co-authored-by: Jan <[email protected]>
3 people authored Jul 22, 2024
1 parent ba19688 commit 83e122a
Showing 15 changed files with 153 additions and 26 deletions.
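The recurring change across the test files below is the new user-facing call pattern: run `process_prior` and `process_simulator` once, then hand the processed callables to `simulate_for_sbi`. Here is a minimal, self-contained sketch of that pattern (the toy prior, simulator, and constants are assumptions for illustration, not code from this commit):

import torch
from sbi.inference import simulate_for_sbi
from sbi.utils import BoxUniform
from sbi.utils.user_input_checks import process_prior, process_simulator

# Toy Gaussian-noise simulator; any callable mapping theta batches to data works.
def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

prior = BoxUniform(-2 * torch.ones(2), 2 * torch.ones(2))

# Standardize prior and simulator once, then simulate in parallel batches.
prior, _, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator, prior, prior_returns_numpy)

theta, x = simulate_for_sbi(
    simulator, prior, num_simulations=1_000, num_workers=4, seed=0
)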
84 changes: 67 additions & 17 deletions sbi/inference/base.py
@@ -8,15 +8,18 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 from warnings import warn
 
+import numpy as np
 import torch
-from torch import Tensor
+from joblib import Parallel, delayed
+from numpy import ndarray
+from torch import Tensor, float32
 from torch.distributions import Distribution
 from torch.utils import data
 from torch.utils.data.sampler import SubsetRandomSampler
 from torch.utils.tensorboard.writer import SummaryWriter
 from tqdm.auto import tqdm
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.simulators.simutils import simulate_in_batches
 from sbi.utils import (
     check_prior,
     get_log_root,
@@ -26,7 +29,7 @@
     validate_theta_and_x,
     warn_if_zscoring_changes_data,
 )
-from sbi.utils.sbiutils import get_simulations_since_round
+from sbi.utils.sbiutils import get_simulations_since_round, seed_all_backends
 from sbi.utils.torchutils import check_if_prior_on_device, process_device
 from sbi.utils.user_input_checks import (
     check_sbi_inputs,
@@ -565,12 +568,17 @@ def __setstate__(self, state_dict: Dict):
         self.__dict__ = state_dict
 
 
+# Refactoring following #1175. tl;dr: letting joblib iterate over numpy arrays
+# allows for a roughly 10x performance gain. The resulting casting necessity
+# (cf. user_input_checks.wrap_as_joblib_efficient_simulator) introduces
+# considerable overhead. The simulation pipeline should therefore be further
+# restructured in the future (PR #1188).
 def simulate_for_sbi(
     simulator: Callable,
     proposal: Any,
     num_simulations: int,
     num_workers: int = 1,
-    simulation_batch_size: int = 1,
+    simulation_batch_size: Union[int, None] = 1,
     seed: Optional[int] = None,
     show_progress_bar: bool = True,
 ) -> Tuple[Tensor, Tensor]:
@@ -590,10 +598,12 @@ def simulate_for_sbi(
             from.
         num_simulations: Number of simulations that are run.
         num_workers: Number of parallel workers to use for simulations.
-        simulation_batch_size: Number of parameter sets that the simulator
-            maps to data x at once. If None, we simulate all parameter sets at the
-            same time. If >= 1, the simulator has to process data of shape
-            (simulation_batch_size, parameter_dimension).
+        simulation_batch_size: Number of parameter sets of shape
+            (simulation_batch_size, parameter_dimension) that the simulator
+            receives per call. If None, we set
+            simulation_batch_size=num_simulations and simulate all parameter
+            sets with one call. Otherwise, we construct batches of parameter
+            sets and distribute them among num_workers.
         seed: Seed for reproducibility.
         show_progress_bar: Whether to show a progress bar for simulating. This will not
             affect whether there will be a progressbar while drawing samples from the
@@ -602,16 +612,56 @@
     Returns: Sampled parameters $\theta$ and simulation-outputs $x$.
     """
 
-    theta = proposal.sample((num_simulations,))
+    if num_simulations == 0:
+        theta = torch.tensor([], dtype=float32)
+        x = torch.tensor([], dtype=float32)
 
-    x = simulate_in_batches(
-        simulator=simulator,
-        theta=theta,
-        sim_batch_size=simulation_batch_size,
-        num_workers=num_workers,
-        seed=seed,
-        show_progress_bars=show_progress_bar,
-    )
+    else:
+        # Cast theta to numpy for better joblib performance (see #1175).
+        seed_all_backends(seed)
+        theta = proposal.sample((num_simulations,)).numpy()
+
+        # Parse the simulation_batch_size logic.
+        if simulation_batch_size is None:
+            simulation_batch_size = num_simulations
+        else:
+            simulation_batch_size = min(simulation_batch_size, num_simulations)
+
+        # The batch size will be an approximation, since np.array_split takes
+        # the total number of batches as argument rather than the batch size.
+        num_batches = num_simulations // simulation_batch_size
+
+        batches = np.array_split(theta, num_batches, axis=0)
+
+        if num_workers != 1:
+            batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),))
+
+            # Define seeded simulator.
+            def simulator_seeded(theta: ndarray, seed: int) -> Tensor:
+                seed_all_backends(seed)
+                return simulator(theta)
+
+            simulation_outputs: list[Tensor] = [  # pyright: ignore
+                xx
+                for xx in tqdm(
+                    Parallel(return_as="generator", n_jobs=num_workers)(
+                        delayed(simulator_seeded)(batch, seed)
+                        for batch, seed in zip(batches, batch_seeds)
+                    ),
+                    total=num_simulations,
+                    disable=not show_progress_bar,
+                )
+            ]
+
+        else:
+            simulation_outputs: list[Tensor] = []
+
+            for batch in tqdm(batches, disable=not show_progress_bar):
+                simulation_outputs.append(simulator(batch))
+
+        # Correctly format the output.
+        x = torch.cat(simulation_outputs, dim=0)
+        theta = torch.as_tensor(theta, dtype=float32)
 
     return theta, x
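Outside the diff, here is a self-contained sketch of the batching-plus-generator pattern the new `simulate_for_sbi` relies on. All names and constants (`toy_simulator`, the sizes, the seeds) are hypothetical, and the per-batch seeding approximates what the real code does via `seed_all_backends`:

import numpy as np
import torch
from joblib import Parallel, delayed
from tqdm.auto import tqdm

def toy_simulator(theta: np.ndarray, seed: int) -> torch.Tensor:
    # Stand-in for a wrapped user simulator; each worker seeds its own RNGs
    # so results stay reproducible across process boundaries.
    torch.manual_seed(seed)
    np.random.seed(seed)
    t = torch.as_tensor(theta, dtype=torch.float32)
    return t + 0.1 * torch.randn_like(t)

num_simulations, batch_size, num_workers = 1_000, 100, 4
theta = np.random.randn(num_simulations, 2)

# np.array_split takes the number of chunks, not their size, so the actual
# batch size only approximates `batch_size` (the same caveat as in the diff).
batches = np.array_split(theta, num_simulations // batch_size, axis=0)
seeds = np.random.randint(0, 1_000_000, size=len(batches))

# return_as="generator" yields results as workers finish, so tqdm can track
# progress; iterating numpy arrays avoids the tensor-serialization overhead
# that kept CPUs under-saturated (#1175 / #1188).
outputs = list(
    tqdm(
        Parallel(return_as="generator", n_jobs=num_workers)(
            delayed(toy_simulator)(batch, seed)
            for batch, seed in zip(batches, seeds)
        ),
        total=len(batches),
    )
)
x = torch.cat(outputs, dim=0)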
21 changes: 19 additions & 2 deletions sbi/utils/user_input_checks.py
@@ -462,15 +462,32 @@ def process_simulator(
 
     assert isinstance(user_simulator, Callable), "Simulator must be a function."
 
-    pytorch_simulator = wrap_as_pytorch_simulator(
+    joblib_simulator = wrap_as_joblib_efficient_simulator(
         user_simulator, prior, is_numpy_simulator
     )
 
-    batch_simulator = ensure_batched_simulator(pytorch_simulator, prior)
+    batch_simulator = ensure_batched_simulator(joblib_simulator, prior)
 
     return batch_simulator
 
 
+# New simulator wrapper, deriving from the #1175 refactoring of simulate_for_sbi.
+# For now it just blindly casts the simulator's input and output to tensors.
+# This is not efficient (~3x slowdown), but it is compatible with the new
+# joblib code and, importantly, does not break previous code. It should be
+# removed with a future restructuring of the simulation pipeline.
+def wrap_as_joblib_efficient_simulator(
+    simulator: Callable, prior, is_numpy_simulator
+) -> Callable:
+    """Return a simulator that accepts `ndarray` input and returns `Tensor` output."""
+
+    def joblib_simulator(theta: ndarray) -> Tensor:
+        return torch.as_tensor(simulator(torch.as_tensor(theta)), dtype=float32)
+
+    return joblib_simulator
+
+
+# Probably not used anymore.
 def wrap_as_pytorch_simulator(
     simulator: Callable, prior, is_numpy_simulator
 ) -> Callable:
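As a quick illustration of what the temporary wrapper does (and where the roughly threefold overhead comes from), here is a hedged sketch with a hypothetical PyTorch-native simulator; `wrap` mimics the shape of `wrap_as_joblib_efficient_simulator` rather than importing it:

import numpy as np
import torch
from torch import Tensor, float32

def wrap(simulator):
    # ndarray in, cast to Tensor for the user simulator, cast the result back
    # out as a float32 Tensor — two casts on every call.
    def joblib_simulator(theta: np.ndarray) -> Tensor:
        return torch.as_tensor(simulator(torch.as_tensor(theta)), dtype=float32)

    return joblib_simulator

sim = wrap(lambda theta: theta * 2.0)  # hypothetical PyTorch-native simulator
x = sim(np.ones((8, 3)))               # accepts ndarray, returns float32 Tensor
assert x.dtype == float32 and x.shape == (8, 3)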
5 changes: 3 additions & 2 deletions tests/inference_on_device_test.py
@@ -22,7 +22,6 @@
     VIPosterior,
     likelihood_estimator_based_potential,
     ratio_estimator_based_potential,
-    simulate_for_sbi,
 )
 from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
 from sbi.inference.potentials.base_potential import BasePotential
@@ -345,7 +344,9 @@ def test_nograd_after_inference_train(inference_method) -> None:
         show_progress_bars=False,
     )
 
-    theta, x = simulate_for_sbi(simulator, prior, 32)
+    num_simulations = 32
+    theta = prior.sample((num_simulations,))
+    x = simulator(theta)
     inference = inference.append_simulations(theta, x)
 
     posterior_estimator = inference.train(max_num_epochs=2)
4 changes: 4 additions & 0 deletions tests/inference_with_NaN_simulator_test.py
@@ -23,6 +23,7 @@
 from sbi.utils.sbiutils import handle_invalid_x
 from sbi.utils.user_input_checks import (
     check_sbi_inputs,
+    process_prior,
     process_simulator,
 )
@@ -187,6 +188,9 @@ def simulator(theta):
         prior = utils.BoxUniform(-2 * torch.ones(2), 2 * torch.ones(2))
     else:
         prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
+
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
     theta, x = simulate_for_sbi(simulator, prior, 1000)
 
     restriction_estimator = RestrictionEstimator(prior=prior)
5 changes: 5 additions & 0 deletions tests/linearGaussian_mdn_test.py
@@ -20,6 +20,7 @@
     linear_gaussian,
     true_posterior_linear_gaussian_mvn_prior,
 )
+from sbi.utils.user_input_checks import process_prior, process_simulator
 from tests.test_utils import check_c2st
 
@@ -49,6 +50,8 @@ def simulator(theta: Tensor) -> Tensor:
 
     inference = method(density_estimator="mdn")
 
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
     theta, x = simulate_for_sbi(simulator, prior, num_simulations)
     estimator = inference.append_simulations(theta, x).train()
     if method == SNPE:
@@ -94,6 +97,8 @@ def simulator(theta: Tensor) -> Tensor:
 
     inference = SNPE(density_estimator="mdn")
 
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
     theta, x = simulate_for_sbi(simulator, prior, 100)
     posterior_estimator = inference.append_simulations(theta, x).train()
     posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)
11 changes: 10 additions & 1 deletion tests/linearGaussian_snle_test.py
@@ -28,6 +28,7 @@
 from sbi.utils import BoxUniform
 from sbi.utils.user_input_checks import (
     process_prior,
+    process_simulator,
 )
 
 from .test_utils import check_c2st, get_prob_outside_uniform_prior
@@ -51,6 +52,8 @@ def test_api_snle_multiple_trials_and_rounds_map(
     prior = BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))
 
     simulator = diagonal_linear_gaussian
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
     inference = SNLE(prior=prior, density_estimator="mdn", show_progress_bars=False)
 
     proposals = [prior]
@@ -116,8 +119,14 @@ def simulator(theta):
     density_estimator = likelihood_nn(model=model_str, num_transforms=3)
     inference = SNLE(density_estimator=density_estimator, show_progress_bars=False)
 
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
     theta, x = simulate_for_sbi(
-        simulator, prior, num_simulations, simulation_batch_size=num_simulations
+        simulator,
+        prior,
+        num_simulations,
+        simulation_batch_size=num_simulations,
+        seed=1,
     )
     likelihood_estimator = inference.append_simulations(theta, x).train()
     potential_fn, theta_transform = likelihood_estimator_based_potential(
13 changes: 13 additions & 0 deletions tests/linearGaussian_snpe_test.py
@@ -31,6 +31,7 @@
     true_posterior_linear_gaussian_mvn_prior,
 )
 from sbi.utils import RestrictedPrior, get_density_thresholder
+from sbi.utils.user_input_checks import process_prior, process_simulator
 
 from .sbiutils_test import conditional_of_mvn
 from .test_utils import (
@@ -80,6 +81,9 @@ def simulator(theta):
 
     inference = snpe_method(prior, show_progress_bars=False)
 
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
+
     theta, x = simulate_for_sbi(
         simulator, prior, num_simulations, simulation_batch_size=1000
     )
@@ -235,6 +239,8 @@ def simulator(theta):
     )
 
     # type: ignore
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
     theta, x = simulate_for_sbi(
         simulator, prior, num_simulations, simulation_batch_size=num_simulations
     )
@@ -477,6 +483,8 @@ def simulator(theta):
         return linear_gaussian(theta, likelihood_shift, likelihood_cov)
 
     inference = SNPE_C(prior, show_progress_bars=False)
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
 
     proposal = prior
     for _ in range(2):
@@ -654,6 +662,9 @@ def simulator(theta):
 
     inference = SNPE_C(density_estimator="mdn", show_progress_bars=False)
 
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
+
     theta, x = simulate_for_sbi(
         simulator, prior, num_simulations, simulation_batch_size=1000
     )
@@ -692,6 +703,8 @@ def simulator(theta):
         return linear_gaussian(theta, likelihood_shift, likelihood_cov)
 
     inference = snpe_method(prior, show_progress_bars=False)
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
 
     theta, x = simulate_for_sbi(
         simulator,
7 changes: 7 additions & 0 deletions tests/linearGaussian_snre_test.py
@@ -29,6 +29,7 @@
     samples_true_posterior_linear_gaussian_uniform_prior,
     true_posterior_linear_gaussian_mvn_prior,
 )
+from sbi.utils.user_input_checks import process_prior, process_simulator
 from tests.test_utils import (
     check_c2st,
     get_dkl_gaussian_prior,
@@ -53,6 +54,9 @@ def test_api_snre_multiple_trials_and_rounds_map(
     simulator = diagonal_linear_gaussian
     inference = snre_method(prior=prior, classifier="mlp", show_progress_bars=False)
 
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
+
     proposals = [prior]
     for _ in range(num_rounds):
         theta, x = simulate_for_sbi(
@@ -109,6 +113,9 @@ def simulator(theta):
 
     inference = snre_method(classifier="resnet", show_progress_bars=False)
 
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
+
     theta, x = simulate_for_sbi(
         simulator, prior, num_simulations, simulation_batch_size=100
     )
5 changes: 3 additions & 2 deletions tests/mcmc_test.py
@@ -26,7 +26,7 @@
     diagonal_linear_gaussian,
     true_posterior_linear_gaussian_mvn_prior,
 )
-from sbi.utils.user_input_checks import process_prior
+from sbi.utils.user_input_checks import process_prior, process_simulator
 from tests.test_utils import check_c2st
 
@@ -202,10 +202,11 @@ def test_getting_inference_diagnostics(method, mcmc_params_fast: dict):
         Uniform(low=-ones(1), high=ones(1)),
     ]
 
-    prior, _, _ = process_prior(prior)
     simulator = diagonal_linear_gaussian
     density_estimator = likelihood_nn("maf", num_transforms=3)
     inference = SNLE(density_estimator=density_estimator, show_progress_bars=False)
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
 
     theta, x = simulate_for_sbi(
         simulator, prior, num_simulations, simulation_batch_size=num_simulations
2 changes: 1 addition & 1 deletion tests/multiprocessing_test.py
@@ -26,7 +26,7 @@ def slow_linear_gaussian(theta):
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("num_workers", [2])
+@pytest.mark.parametrize("num_workers", [4])
 @pytest.mark.parametrize("sim_batch_size", ((1, 10, 100)))
 def test_benchmarking_parallel_simulation(sim_batch_size, num_workers):
     """Test whether joblib is faster than serial processing."""
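The benchmark test above checks that parallel simulation beats serial processing. A rough standalone timing sketch in the same spirit (the sleep-based simulator and all numbers are hypothetical, not taken from the test file):

import time

import numpy as np
import torch
from joblib import Parallel, delayed

def slow_simulator(theta: np.ndarray) -> torch.Tensor:
    time.sleep(0.01)  # emulate expensive per-batch simulation work
    return torch.as_tensor(theta, dtype=torch.float32)

batches = np.array_split(np.random.randn(1_000, 2), 100, axis=0)

tic = time.monotonic()
serial = [slow_simulator(batch) for batch in batches]
t_serial = time.monotonic() - tic

tic = time.monotonic()
parallel = Parallel(n_jobs=4)(delayed(slow_simulator)(batch) for batch in batches)
t_parallel = time.monotonic() - tic

# With enough per-batch work the parallel run should be noticeably faster;
# for tiny workloads, worker startup costs can dominate instead.
print(f"serial: {t_serial:.2f}s, parallel (4 workers): {t_parallel:.2f}s")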
5 changes: 5 additions & 0 deletions tests/plot_test.py
@@ -12,6 +12,7 @@
 from sbi.analysis import pairplot, plot_summary, sbc_rank_plot
 from sbi.inference import SNLE, SNPE, SNRE, simulate_for_sbi
 from sbi.utils import BoxUniform
+from sbi.utils.user_input_checks import process_prior, process_simulator
 
 
 @pytest.mark.parametrize("samples", (torch.randn(100, 1),))
@@ -105,6 +106,10 @@ def simulator(theta):
         return theta + 1.0 + torch.randn_like(theta) * 0.1
 
     inference = method(prior=prior, summary_writer=summary_writer)
+
+    prior, _, prior_returns_numpy = process_prior(prior)
+    simulator = process_simulator(simulator, prior, prior_returns_numpy)
+
     theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=6)
     train_kwargs = (
         dict(max_num_epochs=5, validation_fraction=0.5, num_atoms=2)