Skip to content

Commit

Permalink
Fix blackjax inference loop for newest blackjax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681005933
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Oct 1, 2024
1 parent 9553f12 commit 7fc5a2d
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,16 @@ def _blackjax_inference(
# return from `run_inference_algorithm` changes from
# `_, states, infos` to `_, (states, infos)`. This one weird
# trick handles both cases.
_, *states_and_infos = blackjax.util.run_inference_algorithm(
ret = blackjax.util.run_inference_algorithm(
rng_key=seed,
inference_algorithm=inference_algorithm,
num_steps=num_draws,
progress_bar=False,
**{_INFERENCE_KWARG: adapt_state})
return states_and_infos
if len(ret) == 2: # For newer blackjax versions (1.2.4+)
return ret[1]
else: # Delete this once blackjax 1.2.4 is stable
return ret[1:]


def _blackjax_inference_loop(
Expand Down

0 comments on commit 7fc5a2d

Please sign in to comment.