Skip to content

Commit

Permalink
hotfix to flatten_list_parameters (#327)
Browse files Browse the repository at this point in the history
* hotfix to flatten_list_parameters not working with jax array, adding tests

* fixing mypy
  • Loading branch information
arik-shurygin authored Jan 21, 2025
1 parent c752b46 commit 6bf6a79
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/dynode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,14 +1121,14 @@ def drop_sample_chains(samples: dict, dropped_chain_vals: list):


def flatten_list_parameters(
samples: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
samples: dict[str, np.ndarray | Array],
) -> dict[str, np.ndarray | Array]:
"""
Flatten plated parameters into separate keys in the samples dictionary.
Parameters
----------
samples : dict[str, np.ndarray]
samples : dict[str, np.ndarray | Array]
Dictionary with parameter names as keys and sample
arrays as values. Arrays may have shape MxNxP for P independent draws.
Expand All @@ -1144,7 +1144,7 @@ def flatten_list_parameters(
"""
return_dict = {}
for key, value in samples.items():
if isinstance(value, np.ndarray) and value.ndim > 2:
if isinstance(value, (np.ndarray, Array)) and value.ndim > 2:
num_dims = value.ndim - 2
indices = (
np.indices(value.shape[-num_dims:]).reshape(num_dims, -1).T
Expand Down
4 changes: 2 additions & 2 deletions src/dynode/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def plot_checkpoint_inference_correlation_pairs(
Figure with n rows and n columns where n is the number of sampled parameters.
"""
# convert lists to np.arrays
posteriors: dict[str, np.ndarray] = flatten_list_parameters(
posteriors: dict[str, np.ndarray | Array] = flatten_list_parameters(
{
key: np.array(val) if isinstance(val, list) else val
for key, val in posteriors_in.items()
Expand Down Expand Up @@ -408,7 +408,7 @@ def plot_mcmc_chains(
Matplotlib figure containing the plots.
"""
# Determine the number of parameters and chains
samples: dict[str, np.ndarray] = flatten_list_parameters(
samples: dict[str, np.ndarray | Array] = flatten_list_parameters(
{
key: np.array(val) if isinstance(val, list) else val
for key, val in samples_in.items()
Expand Down
55 changes: 55 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import IntEnum

import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist

from dynode import utils
Expand Down Expand Up @@ -426,3 +427,57 @@ def test_get_timeline_from_solution_with_command_compartment_slice():
assert jnp.all(
timeline == 16
) # Each element in sol is 1, summed over 4*1*1*4 = 16


def test_flatten_list_params_numpy():
# simulate 4 chains and 20 samples each with 4 plated parameters
testing = {"test": np.ones((4, 20, 5))}
flattened = utils.flatten_list_parameters(testing)
assert "test" not in flattened.keys()
for suffix in range(5):
key = "test_%s" % str(suffix)
assert (
key in flattened.keys()
), "flatten_list_parameters not naming split params correctly."
assert flattened[key].shape == (
4,
20,
), "flatten_list_parameters breaking up wrong axis"


def test_flatten_list_params_jax_numpy():
# simulate 4 chains and 20 samples each with 4 plated parameters
# this time with jax numpy instead of numpy
testing = {"test": jnp.ones((4, 20, 5))}
flattened = utils.flatten_list_parameters(testing)
assert "test" not in flattened.keys()
for suffix in range(5):
key = "test_%s" % str(suffix)
assert (
key in flattened.keys()
), "flatten_list_parameters not naming split params correctly."
assert flattened[key].shape == (
4,
20,
), "flatten_list_parameters breaking up wrong axis"


def test_flatten_list_params_multi_dim():
# simulate 4 chains and 20 samples each with 10 plated parameters
# this time with jax numpy instead of numpy
testing = {"test": jnp.ones((4, 20, 5, 2))}
flattened = utils.flatten_list_parameters(testing)
assert "test" not in flattened.keys()
for suffix_first_dim in range(5):
for suffix_second_dim in range(2):
key = "test_%s_%s" % (
str(suffix_first_dim),
str(suffix_second_dim),
)
assert (
key in flattened.keys()
), "flatten_list_parameters not naming split params correctly."
assert flattened[key].shape == (
4,
20,
), "flatten_list_parameters breaking up wrong axis when passed >3"

0 comments on commit 6bf6a79

Please sign in to comment.