Skip to content

Commit ca2db69

Browse files
committed
refactor: update batch size and mcmc defaults.
1 parent b3254ed commit ca2db69

File tree

15 files changed

+21
-21
lines changed

15 files changed

+21
-21
lines changed

sbi/analysis/sensitivity_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def build_mlp(theta):
228228

229229
def train(
230230
self,
231-
training_batch_size: int = 50,
231+
training_batch_size: int = 200,
232232
learning_rate: float = 5e-4,
233233
validation_fraction: float = 0.1,
234234
stop_after_epochs: int = 20,

sbi/inference/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def append_simulations(
295295
@abstractmethod
296296
def train(
297297
self,
298-
training_batch_size: int = 50,
298+
training_batch_size: int = 200,
299299
learning_rate: float = 5e-4,
300300
validation_fraction: float = 0.1,
301301
stop_after_epochs: int = 20,
@@ -312,7 +312,7 @@ def train(
312312
def get_dataloaders(
313313
self,
314314
starting_round: int = 0,
315-
training_batch_size: int = 50,
315+
training_batch_size: int = 200,
316316
validation_fraction: float = 0.1,
317317
resume_training: bool = False,
318318
dataloader_kwargs: Optional[dict] = None,

sbi/inference/fmpe/fmpe_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def append_simulations(
126126

127127
def train(
128128
self,
129-
training_batch_size: int = 50,
129+
training_batch_size: int = 200,
130130
learning_rate: float = 5e-4,
131131
validation_fraction: float = 0.1,
132132
stop_after_epochs: int = 20,

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def __init__(
4747
potential_fn: Union[Callable, BasePotential],
4848
proposal: Any,
4949
theta_transform: Optional[TorchTransform] = None,
50-
method: str = "slice_np",
50+
method: str = "slice_np_vectorized",
5151
thin: int = -1,
5252
warmup_steps: int = 200,
53-
num_chains: int = 1,
53+
num_chains: int = 20,
5454
init_strategy: str = "resample",
5555
init_strategy_parameters: Optional[Dict[str, Any]] = None,
5656
init_strategy_num_candidates: Optional[int] = None,

sbi/inference/snle/mnle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666

6767
def train(
6868
self,
69-
training_batch_size: int = 50,
69+
training_batch_size: int = 200,
7070
learning_rate: float = 5e-4,
7171
validation_fraction: float = 0.1,
7272
stop_after_epochs: int = 20,
@@ -92,7 +92,7 @@ def build_posterior(
9292
density_estimator: Optional[TorchModule] = None,
9393
prior: Optional[Distribution] = None,
9494
sample_with: str = "mcmc",
95-
mcmc_method: str = "slice_np",
95+
mcmc_method: str = "slice_np_vectorized",
9696
vi_method: str = "rKL",
9797
mcmc_parameters: Optional[Dict[str, Any]] = None,
9898
vi_parameters: Optional[Dict[str, Any]] = None,

sbi/inference/snle/snle_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def append_simulations(
118118

119119
def train(
120120
self,
121-
training_batch_size: int = 50,
121+
training_batch_size: int = 200,
122122
learning_rate: float = 5e-4,
123123
validation_fraction: float = 0.1,
124124
stop_after_epochs: int = 20,
@@ -267,7 +267,7 @@ def build_posterior(
267267
density_estimator: Optional[ConditionalDensityEstimator] = None,
268268
prior: Optional[Distribution] = None,
269269
sample_with: str = "mcmc",
270-
mcmc_method: str = "slice_np",
270+
mcmc_method: str = "slice_np_vectorized",
271271
vi_method: str = "rKL",
272272
mcmc_parameters: Optional[Dict[str, Any]] = None,
273273
vi_parameters: Optional[Dict[str, Any]] = None,

sbi/inference/snpe/snpe_a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
def train(
104104
self,
105105
final_round: bool = False,
106-
training_batch_size: int = 50,
106+
training_batch_size: int = 200,
107107
learning_rate: float = 5e-4,
108108
validation_fraction: float = 0.1,
109109
stop_after_epochs: int = 20,

sbi/inference/snpe/snpe_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def append_simulations(
209209

210210
def train(
211211
self,
212-
training_batch_size: int = 50,
212+
training_batch_size: int = 200,
213213
learning_rate: float = 5e-4,
214214
validation_fraction: float = 0.1,
215215
stop_after_epochs: int = 20,
@@ -435,7 +435,7 @@ def build_posterior(
435435
density_estimator: Optional[ConditionalDensityEstimator] = None,
436436
prior: Optional[Distribution] = None,
437437
sample_with: str = "direct",
438-
mcmc_method: str = "slice_np",
438+
mcmc_method: str = "slice_np_vectorized",
439439
vi_method: str = "rKL",
440440
direct_sampling_parameters: Optional[Dict[str, Any]] = None,
441441
mcmc_parameters: Optional[Dict[str, Any]] = None,

sbi/inference/snpe/snpe_c.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(
9090
def train(
9191
self,
9292
num_atoms: int = 10,
93-
training_batch_size: int = 50,
93+
training_batch_size: int = 200,
9494
learning_rate: float = 5e-4,
9595
validation_fraction: float = 0.1,
9696
stop_after_epochs: int = 20,

sbi/inference/snre/bnre.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
def train(
5757
self,
5858
regularization_strength: float = 100.0,
59-
training_batch_size: int = 50,
59+
training_batch_size: int = 200,
6060
learning_rate: float = 5e-4,
6161
validation_fraction: float = 0.1,
6262
stop_after_epochs: int = 20,

0 commit comments

Comments
 (0)