Skip to content

Commit

Permalink
GPU compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Jun 11, 2024
1 parent 7d4eb55 commit f53e1ec
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sbi/samplers/rejection/rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ def accept_reject_sample(
proposal_sampling_kwargs = {}

num_remaining = num_samples

# NOTE: We might want to change this to a more general approach in the future.
# Currently limited to a single "batch_dim" for the condition.
# But this would require giving the method the condition_shape explicitly...
if "condition" in proposal_sampling_kwargs:
num_xos = proposal_sampling_kwargs["condition"].shape[0]
else:
Expand Down Expand Up @@ -295,7 +299,7 @@ def accept_reject_sample(
# samples will be of shape(sampling_batch_size,*batch_shape, *event_shape)
# and hence work in dim = 0.
num_accepted = are_accepted.sum(dim=0)
num_sampled_total += num_accepted
num_sampled_total += num_accepted.to(num_sampled_total.device)
num_samples_possible += sampling_batch_size
min_num_accepted = num_accepted.min().item()
num_remaining -= min_num_accepted
Expand Down Expand Up @@ -358,4 +362,4 @@ def accept_reject_sample(
samples.shape[0] == num_samples
), "Number of accepted samples must match required samples."

return samples, as_tensor(min_acceptance_rate)
return samples, acceptance_rate

0 comments on commit f53e1ec

Please sign in to comment.