fix: slow MNLE tests and warning formatting (#1211)
* fix: more sims for MNLE slow test

* refactor: improve warning clarity and formatting
janfb authored Aug 6, 2024
1 parent f9ec0bd commit 724ea4f
Showing 9 changed files with 70 additions and 74 deletions.
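The recurring refactor across these files replaces multi-line triple-quoted warning messages with adjacent string literals. A triple-quoted string inside indented code embeds the source file's newlines and leading spaces in the message users see; adjacent literals are concatenated at compile time into a single clean line. A minimal sketch of the pattern (message text abbreviated for illustration):

    import warnings

    # Before: indentation and newlines leak into the warning text.
    warnings.warn(
        """Number of SBC samples should be on the order of 100s
        to give reliable results.""",
        stacklevel=2,
    )

    # After: adjacent literals compile to one single-line message.
    warnings.warn(
        "Number of SBC samples should be on the order of 100s "
        "to give reliable results.",
        stacklevel=2,
    )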
27 changes: 13 additions & 14 deletions sbi/diagnostics/sbc.py
@@ -56,14 +56,14 @@ def run_sbc(
 
     if num_sbc_samples < 100:
         warnings.warn(
-            """Number of SBC samples should be on the order of 100s to give realiable
-            results.""",
+            "Number of SBC samples should be on the order of 100s to give reliable "
+            "results.",
             stacklevel=2,
         )
     if num_posterior_samples < 100:
         warnings.warn(
-            """Number of posterior samples for ranking should be on the order
-            of 100s to give reliable SBC results.""",
+            "Number of posterior samples for ranking should be on the order "
+            "of 100s to give reliable SBC results.",
             stacklevel=2,
         )
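Every call here also passes `stacklevel=2`, so the warning is attributed to the code that called `run_sbc` rather than to the `warnings.warn` line itself. A standard-library-only illustration:

    import warnings

    def run_check():
        # With the default stacklevel=1, the warning would point here;
        # stacklevel=2 points one frame up, at run_check()'s caller.
        warnings.warn("Number of samples is low.", stacklevel=2)

    run_check()  # the reported source line is this one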

@@ -73,8 +73,8 @@ def run_sbc(
 
     if "sbc_batch_size" in kwargs:
         warnings.warn(
-            """`sbc_batch_size` is deprecated and will be removed in future versions.
-            Use `num_workers` instead.""",
+            "`sbc_batch_size` is deprecated and will be removed in future versions."
+            " Use `num_workers` instead.",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -182,8 +182,8 @@ def get_nltp(thetas: Tensor, xs: Tensor, posterior: NeuralPosterior) -> Tensor:
 
     if unnormalized_log_prob:
         warnings.warn(
-            """Note that log probs of the true parameters under the posteriors
-            are not normalized because the posterior used is likelihood-based.""",
+            "Note that log probs of the true parameters under the posteriors are not "
+            "normalized because the posterior used is likelihood-based.",
             stacklevel=2,
         )
@@ -216,9 +216,9 @@ def check_sbc(
     """
     if ranks.shape[0] < 100:
         warnings.warn(
-            """You are computing SBC checks with less than 100 samples. These checks
-            should be based on a large number of test samples theta_o, x_o. We
-            recommend using at least 100.""",
+            "You are computing SBC checks with less than 100 samples. These checks"
+            " should be based on a large number of test samples theta_o, x_o. We"
+            " recommend using at least 100.",
             stacklevel=2,
         )
 
@@ -315,9 +315,8 @@ def check_uniformity_c2st(
     c2st_std = c2st_scores.std(0, correction=0 if num_repetitions == 1 else 1)
     if (c2st_std > 0.05).any():
         warnings.warn(
-            f"""C2ST score variability is larger than {0.05}: std={c2st_scores.std(0)},
-            result may be unreliable. Consider increasing the number of samples.
-            """,
+            f"C2ST score variability is larger than {0.05}: std={c2st_scores.std(0)}, "
+            "result may be unreliable. Consider increasing the number of samples.",
             stacklevel=2,
         )
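A note on the context line above: `correction=0 if num_repetitions == 1 else 1` avoids a NaN, because the Bessel-corrected standard deviation divides by n - 1 and is undefined for a single repetition. A quick check:

    import torch

    scores = torch.tensor([[0.52, 0.61]])   # one repetition of two C2ST scores
    print(scores.std(0, correction=1))      # tensor([nan, nan]): divides by 1 - 1 = 0
    print(scores.std(0, correction=0))      # tensor([0., 0.]): population std is defined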

40 changes: 19 additions & 21 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -136,9 +136,9 @@ def __init__(
 
         if init_strategy_num_candidates is not None:
             warn(
-                """Passing `init_strategy_num_candidates` is deprecated as of sbi
-                v0.19.0. Instead, use e.g.,
-                `init_strategy_parameters={"num_candidate_samples": 1000}`""",
+                "Passing `init_strategy_num_candidates` is deprecated as of sbi "
+                "v0.19.0. Instead, use e.g., "
+                "`init_strategy_parameters={'num_candidate_samples': 1000}`",
                 stacklevel=2,
             )
             self.init_strategy_parameters["num_candidate_samples"] = (
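The replacement message above is kept a plain (non-f) string on its last line, because a literal `{'num_candidate_samples': 1000}` inside an f-string would be parsed as a replacement field, with the colon starting a format spec. If interpolation were needed, the braces would have to be doubled; a small illustration of both options:

    # Plain string: braces are literal, nothing to escape.
    hint = "`init_strategy_parameters={'num_candidate_samples': 1000}`"

    # f-string: literal braces must be doubled to survive formatting.
    n = 1000
    hint = f"`init_strategy_parameters={{'num_candidate_samples': {n}}}`"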
@@ -194,9 +194,8 @@ def log_prob(
             `len($\theta$)`-shaped log-probability.
         """
         warn(
-            """`.log_prob()` is deprecated for methods that can only evaluate the
-            log-probability up to a normalizing constant. Use `.potential()`
-            instead.""",
+            "`.log_prob()` is deprecated for methods that can only evaluate the "
+            "log-probability up to a normalizing constant. Use `.potential()` instead.",
             stacklevel=2,
         )
         warn("The log-probability is unnormalized!", stacklevel=2)
@@ -264,9 +263,9 @@ def sample(
         )
         if init_strategy_num_candidates is not None:
             warn(
-                """Passing `init_strategy_num_candidates` is deprecated as of sbi
-                v0.19.0. Instead, use e.g.,
-                `init_strategy_parameters={"num_candidate_samples": 1000}`""",
+                "Passing `init_strategy_num_candidates` is deprecated as of sbi "
+                "v0.19.0. Instead, use e.g., "
+                "`init_strategy_parameters={'num_candidate_samples': 1000}`",
                 stacklevel=2,
             )
             self.init_strategy_parameters["num_candidate_samples"] = (
@@ -275,7 +274,7 @@ def sample(
         if sample_with is not None:
             raise ValueError(
                 f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
-                f"`sample_with` is no longer supported. You have to rerun "
+                "`sample_with` is no longer supported. You have to rerun "
                 f"`.build_posterior(sample_with={sample_with}).`"
             )
         if mcmc_method is not None:
@@ -426,9 +425,9 @@ def sample_batched(
         # warn if num_chains is larger than num requested samples
         if num_chains > torch.Size(sample_shape).numel():
             warnings.warn(
-                f"""Passed num_chains {num_chains} is larger than the number of
-                requested samples {torch.Size(sample_shape).numel()}, resetting
-                it to {torch.Size(sample_shape).numel()}.""",
+                "The passed number of MCMC chains is larger than the number of "
+                f"requested samples: {num_chains} > {torch.Size(sample_shape).numel()},"
+                f" resetting it to {torch.Size(sample_shape).numel()}.",
                 stacklevel=2,
             )
             num_chains = torch.Size(sample_shape).numel()
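`torch.Size(sample_shape).numel()` gives the total number of requested samples for any nesting of the sample shape, which is what the guard above compares against. A sketch of the same logic in isolation:

    import torch

    sample_shape = (10, 5)                           # 50 samples requested
    num_requested = torch.Size(sample_shape).numel()

    num_chains = 100
    if num_chains > num_requested:
        num_chains = num_requested                   # capped at 50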
@@ -453,12 +452,11 @@ def sample_batched(
             num_chains_extended = batch_size * num_chains
             if num_chains_extended > 100:
                 warnings.warn(
-                    f"""Note that for batched sampling, we use {num_chains} for each
-                    x in the batch. With the given settings, this results in a
-                    large number of chains ({num_chains_extended}), This can be
-                    large number of chains ({num_chains_extended}), which can be
-                    slow and memory-intensive. Consider reducing the number of
-                    chains.""",
+                    "Note that for batched sampling, we use num_chains many chains "
+                    "for each x in the batch. With the given settings, this results "
+                    f"in a large number of chains ({num_chains_extended}), which can "
+                    "be slow and memory-intensive for vectorized MCMC. Consider "
+                    "reducing the number of chains.",
                     stacklevel=2,
                 )
             init_strategy_parameters["num_return_samples"] = num_chains_extended
@@ -905,8 +903,8 @@ def _prepare_potential(self, method: str) -> Callable:
         else:
             if "hmc" in method or "nuts" in method:
                 warn(
-                    """The kwargs "hmc" and "nuts" are deprecated. Use "hmc_pyro",
-                    "nuts_pyro", "hmc_pymc", or "nuts_pymc" instead.""",
+                    "The kwargs 'hmc' and 'nuts' are deprecated. Use 'hmc_pyro', "
+                    "'nuts_pyro', 'hmc_pymc', or 'nuts_pymc' instead.",
                     DeprecationWarning,
                     stacklevel=2,
                 )
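Unlike the UserWarnings elsewhere in this diff, these pass `DeprecationWarning` as the category, which Python hides by default outside of `__main__`. Callers who want deprecations to surface (or fail loudly in a test suite) can opt in; a minimal sketch:

    import warnings

    # Opt in: show every deprecation (hidden by default outside __main__) ...
    warnings.simplefilter("always", DeprecationWarning)

    # ... or escalate them to errors instead, e.g. in a test suite:
    # warnings.simplefilter("error", DeprecationWarning)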
5 changes: 2 additions & 3 deletions sbi/inference/posteriors/rejection_posterior.py
@@ -84,9 +84,8 @@ def log_prob(
             `len($\theta$)`-shaped log-probability.
         """
         warn(
-            """`.log_prob()` is deprecated for methods that can only evaluate the
-            log-probability up to a normalizing constant. Use `.potential()`
-            instead.""",
+            "`.log_prob()` is deprecated for methods that can only evaluate the "
+            "log-probability up to a normalizing constant. Use `.potential()` instead.",
             stacklevel=2,
         )
         warn("The log-probability is unnormalized!", stacklevel=2)
8 changes: 4 additions & 4 deletions sbi/neural_nets/mnle.py
@@ -124,10 +124,10 @@ def build_mnle(
     check_data_device(batch_x, batch_y)
 
     warnings.warn(
-        """The mixed neural likelihood estimator assumes that x contains
-        continuous data in the first n-1 columns (e.g., reaction times) and
-        categorical data in the last column (e.g., corresponding choices). If
-        this is not the case for the passed `x` do not use this function.""",
+        "The mixed neural likelihood estimator assumes that x contains "
+        "continuous data in the first n-1 columns (e.g., reaction times) and "
+        "categorical data in the last column (e.g., corresponding choices). If "
+        "this is not the case for the passed `x`, do not use this function.",
         stacklevel=2,
     )
     # Separate continuous and discrete data.
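The comment on the last context line is where the layout the warning describes gets used: continuous columns first, one categorical column last, split by a simple slice. A hedged sketch with illustrative shapes (not the library's exact code):

    import torch

    num_trials = 100
    x = torch.cat(
        [
            torch.rand(num_trials, 2),                     # continuous, e.g. reaction times
            torch.randint(0, 2, (num_trials, 1)).float(),  # categorical, e.g. choices
        ],
        dim=1,
    )

    cont_x, disc_x = x[:, :-1], x[:, -1]   # first n-1 columns vs. last column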
4 changes: 2 additions & 2 deletions sbi/samplers/mcmc/slice_numpy.py
@@ -394,8 +394,8 @@ def __init__(
         # TODO: implement parallelization across batches of chains.
         if num_workers > 1:
             warn(
-                """Parallelization of vectorized slice sampling not implement, running
-                serially.""",
+                "Parallelization of vectorized slice sampling is not implemented; "
+                "running serially.",
                 stacklevel=2,
             )
         self._reset()
42 changes: 21 additions & 21 deletions sbi/utils/sbiutils.py
@@ -43,9 +43,9 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
     # Check we do have different data in the batch
     if num_unique == 1:
         warnings.warn(
-            """Beware that there is only a single unique element in the simulated data.
-            If this is intended, make sure to set `z_score_x='none'` as z-scoring would
-            result in NaNs""",
+            "Beware that there is only a single unique element in the simulated "
+            "data. If this is intended, make sure to set `z_score_x='none'`, as "
+            "z-scoring would result in NaNs.",
             UserWarning,
             stacklevel=2,
         )
@@ -61,13 +61,14 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
 
     if num_unique_z < num_unique * (1 - duplicate_tolerance):
         warnings.warn(
-            """Z-scoring these simulation outputs resulted in {num_unique_z} unique
-            datapoints. Before z-scoring, it had been {num_unique}. This can occur
-            due to numerical inaccuracies when the data covers a large range of
-            values. Consider either setting `z_score_x=False` (but beware that this
-            can be problematic for training the NN) or exclude outliers from your
-            dataset. Note: if you have already set `z_score_x=False`, this warning
-            will still be displayed, but you can ignore it.""",
+            f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
+            f"datapoints. Before z-scoring, it had been {num_unique}. This can "
+            "occur due to numerical inaccuracies when the data covers a large "
+            "range of values. Consider either setting `z_score_x=False` (but "
+            "beware that this can be problematic for training the NN) or excluding "
+            "outliers from your dataset. Note: if you have already set "
+            "`z_score_x=False`, this warning will still be displayed, but you can "
+            "ignore it.",
             UserWarning,
             stacklevel=2,
         )
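The failure mode this warning detects is easy to reproduce: with float32 data spanning a large range, standardization can map distinct simulation outputs to the same value. A small demonstration (exact counts depend on dtype and platform):

    import torch

    x = torch.tensor([1.0, 1.0001, 1e8])   # two nearby values plus an extreme outlier
    print(torch.unique(x).numel())          # 3
    z = (x - x.mean()) / x.std()
    print(torch.unique(z).numel())          # 2: the nearby pair collapses after z-scoring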
@@ -406,11 +407,11 @@ def warn_on_batched_x(batch_size):
     if batch_size > 1:
         warnings.warn(
             f"An x with a batch size of {batch_size} was passed. "
-            + """Unless you are using `sample_batched` or `log_prob_batched`, this will
-            be interpreted as a batch of independent and identically distributed data
-            X={x_1, ..., x_n}, i.e., data generated based on the same underlying
-            (unknown) parameter. The resulting posterior will be with respect to entire
-            batch, i.e,. p(theta | X).""",
+            "Unless you are using `sample_batched` or `log_prob_batched`, this will "
+            "be interpreted as a batch of independent and identically distributed "
+            "data X={x_1, ..., x_n}, i.e., data generated based on the same "
+            "underlying (unknown) parameter. The resulting posterior will be with "
+            "respect to the entire batch, i.e., p(theta | X).",
             stacklevel=2,
         )

@@ -714,9 +715,9 @@ def mcmc_transform(
             # does not implement support.
             # AttributeError -> Custom distribution that has no support attribute.
             warnings.warn(
-                """The passed prior has no support property, transform will be
-                constructed from mean and std. If the passed prior is supposed to be
-                bounded consider implementing the prior.support property.""",
+                "The passed prior has no support property; the transform will be "
+                "constructed from mean and std. If the passed prior is supposed to "
+                "be bounded, consider implementing the `prior.support` property.",
                 stacklevel=2,
             )
             has_support = False
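Built-in torch distributions already expose the `support` property this branch looks for; a custom prior can advertise bounded support the same way, so that `mcmc_transform` builds a bounds-respecting transform instead of the mean/std fallback. A sketch:

    from torch.distributions import Uniform, constraints

    print(Uniform(0.0, 1.0).support)        # an interval constraint on [0, 1]

    class CustomPrior:
        # Advertising bounded support lets downstream code pick a suitable transform.
        support = constraints.interval(-2.0, 2.0)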
@@ -749,9 +750,9 @@ def mcmc_transform(
             # does not implement mean, e.g., TransformedDistribution.
             # AttributeError -> Custom distribution that has no mean/std attribute.
             warnings.warn(
-                """The passed prior has no mean or stddev attribute, estimating
-                them from samples to build affimed standardizing
-                transform.""",
+                "The passed prior has no mean or stddev attribute; estimating "
+                "them from samples to build an affine standardizing transform.",
                 stacklevel=2,
             )
             theta = prior.sample(torch.Size((num_prior_samples_for_zscoring,)))
6 changes: 3 additions & 3 deletions sbi/utils/user_input_checks.py
@@ -71,9 +71,9 @@ def process_prior(
     # If prior is a sequence, assume independent components and check as PyTorch prior.
     if isinstance(prior, Sequence):
         warnings.warn(
-            f"""Prior was provided as a sequence of {len(prior)} priors. They will be
-            interpreted as independent of each other and matched in order to the
-            components of the parameter.""",
+            f"Prior was provided as a sequence of {len(prior)} priors. They will be "
+            "interpreted as independent of each other and matched in order to the "
+            "components of the parameter.",
             stacklevel=2,
         )
         # process individual priors
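Usage that triggers this warning, sketched under the assumption that `process_prior` is imported from `sbi.utils` (the exact return signature may differ across versions):

    import torch
    from torch.distributions import Normal, Uniform
    from sbi.utils import process_prior

    prior_list = [
        Normal(torch.zeros(1), torch.ones(1)),      # component 1
        Uniform(-torch.ones(1), torch.ones(1)),     # component 2
    ]
    # Warns, then wraps the sequence as a joint prior with independent components.
    prior, num_parameters, returns_numpy = process_prior(prior_list)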
8 changes: 4 additions & 4 deletions sbi/utils/user_input_checks_utils.py
@@ -77,8 +77,8 @@ def _set_mean_and_variance(self):
                 ** 2
             )
             warnings.warn(
-                """Prior is lacking variance attribute, estimating prior variance from
-                samples...""",
+                "Prior is lacking a variance attribute; estimating prior variance "
+                "from samples.",
                 UserWarning,
                 stacklevel=2,
             )
@@ -333,8 +333,8 @@ def build_support(
     if lower_bound is None and upper_bound is None:
         support = constraints.real
         warnings.warn(
-            """No prior bounds were passed, consider passing lower_bound
-            and / or upper_bound if your prior has bounded support.""",
+            "No prior bounds were passed; consider passing `lower_bound` "
+            "and/or `upper_bound` if your prior has bounded support.",
             stacklevel=2,
         )
     # Only lower bound is specified.
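The edited branch is the no-bounds case; the remaining branches (the comment above marks the first of them) map the given bounds onto the matching torch constraint. A condensed, hedged re-sketch of that mapping:

    from torch.distributions import constraints

    def sketch_build_support(lower_bound=None, upper_bound=None):
        if lower_bound is None and upper_bound is None:
            return constraints.real                        # unbounded: warned about above
        if upper_bound is None:
            return constraints.greater_than(lower_bound)   # only lower bound specified
        if lower_bound is None:
            return constraints.less_than(upper_bound)      # only upper bound specified
        return constraints.interval(lower_bound, upper_bound)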
4 changes: 2 additions & 2 deletions tests/mnle_test.py
@@ -135,7 +135,7 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
 ):
     """Test MNLE c2st accuracy for different samplers and number of trials."""
 
-    num_simulations = 3000
+    num_simulations = 3200
     num_samples = 500
 
     prior = MultipleIndependent(
@@ -152,7 +152,7 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
     # MNLE
     density_estimator = likelihood_nn(model="mnle", flow_model=flow_model)
     trainer = MNLE(prior, density_estimator=density_estimator)
-    trainer.append_simulations(theta, x).train()
+    trainer.append_simulations(theta, x).train(training_batch_size=200)
     posterior = trainer.build_posterior()
 
     theta_o = prior.sample((1,))
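For context, the two edited lines sit inside the standard MNLE pipeline; compressed, the surrounding flow looks roughly like this (names follow the test above, with `simulator` standing in for the test's mixed simulator):

    theta = prior.sample((num_simulations,))
    x = simulator(theta)   # continuous columns plus a final categorical column

    density_estimator = likelihood_nn(model="mnle", flow_model=flow_model)
    trainer = MNLE(prior, density_estimator=density_estimator)
    trainer.append_simulations(theta, x).train(training_batch_size=200)
    posterior = trainer.build_posterior()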
