 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.inference.posteriors.vi_posterior import VIPosterior
+from sbi.sbi_types import Shape


 def get_posterior_samples_on_batch(
     xs: Tensor,
     posterior: NeuralPosterior,
-    num_samples: int,
+    sample_shape: Shape,
     num_workers: int = 1,
     show_progress_bar: bool = False,
+    use_batched_sampling: bool = True,
 ) -> Tensor:
     """Get posterior samples for a batch of xs.

     Args:
         xs: batch of observations.
         posterior: sbi posterior.
-        num_posterior_samples: number of samples to draw from the posterior in each sbc
-            run.
+        sample_shape: shape of samples to draw from the posterior for each x.
         num_workers: number of workers to use for parallelization.
         show_progress_bars: whether to show progress bars.
+        use_batched_sampling: whether to use batched sampling if possible.

     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
@@ -32,35 +34,44 @@ def get_posterior_samples_on_batch(
     # Try using batched sampling when implemented.
     try:
         # has shape (num_samples, batch_size, dim_parameters)
-        posterior_samples = posterior.sample_batched(
-            (num_samples,), xs, show_progress_bars=show_progress_bar
-        )
+        if use_batched_sampling:
+            posterior_samples = posterior.sample_batched(
+                sample_shape, x=xs, show_progress_bars=show_progress_bar
+            )
+        else:
+            raise NotImplementedError
     except NotImplementedError:
         # We need a function with extra training step for new x for VIPosterior.
-        def sample_fun(posterior: NeuralPosterior, sample_shape, x: Tensor) -> Tensor:
+        def sample_fun(
+            posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
+        ) -> Tensor:
             if isinstance(posterior, VIPosterior):
                 posterior.set_default_x(x)
                 posterior.train()
+            torch.manual_seed(seed)
             return posterior.sample(sample_shape, x=x, show_progress_bars=False)

         # Run in parallel with progress bar.
+        seeds = torch.randint(0, 2**32, (batch_size,))
         outputs = list(
             tqdm(
                 Parallel(return_as="generator", n_jobs=num_workers)(
-                    delayed(sample_fun)(posterior, (num_samples,), x=x) for x in xs
+                    delayed(sample_fun)(posterior, sample_shape, x=x, seed=s)
+                    for x, s in zip(xs, seeds)
                 ),
                 disable=not show_progress_bar,
                 total=len(xs),
-                desc=f"Sampling {batch_size} times {num_samples} posterior samples.",
+                desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
             )
-        )
-        # Transpose to sample_batched shape convention:
-        posterior_samples = torch.stack(outputs).transpose(0, 1)  # type: ignore
+        )  # (batch_size, num_samples, dim_parameters)
+        # Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
+        posterior_samples = torch.stack(
+            outputs  # type: ignore
+        ).permute(1, 0, 2)

-    assert posterior_samples.shape[:2] == (
-        num_samples,
+    assert posterior_samples.shape[:2] == sample_shape + (
         batch_size,
-    ), f"""Expected batched posterior samples of shape {(num_samples, batch_size)} got {
-        posterior_samples.shape[:2]
-    }."""
+    ), f"""Expected batched posterior samples of shape {
+        sample_shape + (batch_size,)
+    } got {posterior_samples.shape[:2]}."""
     return posterior_samples
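
For reference, here is a minimal usage sketch of the updated signature. It is an illustration, not part of the diff: `posterior` stands for a NeuralPosterior that has already been trained with sbi, `xs` is a made-up batch of observations, and the import path `sbi.utils.diagnostics_utils` is an assumption about where the function lives.

import torch

# Assumed import path; adjust to where get_posterior_samples_on_batch
# is defined in your sbi version.
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch

# Placeholder inputs: `posterior` must be an already-trained NeuralPosterior,
# `xs` a batch of observations of shape (batch_size, dim_x).
xs = torch.randn(10, 3)

samples = get_posterior_samples_on_batch(
    xs=xs,
    posterior=posterior,
    sample_shape=(1000,),       # 1000 samples per observation
    num_workers=4,              # only used by the per-x fallback path
    show_progress_bar=True,
    use_batched_sampling=True,  # set to False to force the per-x fallback
)
# Shape convention from the docstring: (num_samples, batch_size, dim_parameters).
print(samples.shape)  # e.g., torch.Size([1000, 10, dim_parameters])

Both the batched path and the per-x fallback return samples in the same (sample_shape, batch_size, dim_parameters) layout, which is what the assert at the end of the function checks.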