From 6bf6a79380c8176a94e4b497b55f8e0483f27b2e Mon Sep 17 00:00:00 2001 From: Ariel Shurygin <39861882+arik-shurygin@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:25:22 -0800 Subject: [PATCH] hotfix to flatten_list_parameters (#327) * hotfix to flatten_list_parameters not working with jax array, adding tests * fixing mypy --- src/dynode/utils.py | 8 +++--- src/dynode/vis_utils.py | 4 +-- tests/test_utils.py | 55 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/src/dynode/utils.py b/src/dynode/utils.py index 5bb1d5dd..566bc1d6 100644 --- a/src/dynode/utils.py +++ b/src/dynode/utils.py @@ -1121,14 +1121,14 @@ def drop_sample_chains(samples: dict, dropped_chain_vals: list): def flatten_list_parameters( - samples: dict[str, np.ndarray], -) -> dict[str, np.ndarray]: + samples: dict[str, np.ndarray | Array], +) -> dict[str, np.ndarray | Array]: """ Flatten plated parameters into separate keys in the samples dictionary. Parameters ---------- - samples : dict[str, np.ndarray] + samples : dict[str, np.ndarray | Array] Dictionary with parameter names as keys and sample arrays as values. Arrays may have shape MxNxP for P independent draws. @@ -1144,7 +1144,7 @@ def flatten_list_parameters( """ return_dict = {} for key, value in samples.items(): - if isinstance(value, np.ndarray) and value.ndim > 2: + if isinstance(value, (np.ndarray, Array)) and value.ndim > 2: num_dims = value.ndim - 2 indices = ( np.indices(value.shape[-num_dims:]).reshape(num_dims, -1).T diff --git a/src/dynode/vis_utils.py b/src/dynode/vis_utils.py index a443ec20..3d9b5fba 100644 --- a/src/dynode/vis_utils.py +++ b/src/dynode/vis_utils.py @@ -280,7 +280,7 @@ def plot_checkpoint_inference_correlation_pairs( Figure with n rows and n columns where n is the number of sampled parameters. """ # convert lists to np.arrays - posteriors: dict[str, np.ndarray] = flatten_list_parameters( + posteriors: dict[str, np.ndarray | Array] = flatten_list_parameters( { key: np.array(val) if isinstance(val, list) else val for key, val in posteriors_in.items() @@ -408,7 +408,7 @@ def plot_mcmc_chains( Matplotlib figure containing the plots. """ # Determine the number of parameters and chains - samples: dict[str, np.ndarray] = flatten_list_parameters( + samples: dict[str, np.ndarray | Array] = flatten_list_parameters( { key: np.array(val) if isinstance(val, list) else val for key, val in samples_in.items() diff --git a/tests/test_utils.py b/tests/test_utils.py index 32a548a4..24a45a74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ from enum import IntEnum import jax.numpy as jnp +import numpy as np import numpyro.distributions as dist from dynode import utils @@ -426,3 +427,57 @@ def test_get_timeline_from_solution_with_command_compartment_slice(): assert jnp.all( timeline == 16 ) # Each element in sol is 1, summed over 4*1*1*4 = 16 + + +def test_flatten_list_params_numpy(): + # simulate 4 chains and 20 samples each with 4 plated parameters + testing = {"test": np.ones((4, 20, 5))} + flattened = utils.flatten_list_parameters(testing) + assert "test" not in flattened.keys() + for suffix in range(5): + key = "test_%s" % str(suffix) + assert ( + key in flattened.keys() + ), "flatten_list_parameters not naming split params correctly." + assert flattened[key].shape == ( + 4, + 20, + ), "flatten_list_parameters breaking up wrong axis" + + +def test_flatten_list_params_jax_numpy(): + # simulate 4 chains and 20 samples each with 4 plated parameters + # this time with jax numpy instead of numpy + testing = {"test": jnp.ones((4, 20, 5))} + flattened = utils.flatten_list_parameters(testing) + assert "test" not in flattened.keys() + for suffix in range(5): + key = "test_%s" % str(suffix) + assert ( + key in flattened.keys() + ), "flatten_list_parameters not naming split params correctly." + assert flattened[key].shape == ( + 4, + 20, + ), "flatten_list_parameters breaking up wrong axis" + + +def test_flatten_list_params_multi_dim(): + # simulate 4 chains and 20 samples each with 10 plated parameters + # this time with jax numpy instead of numpy + testing = {"test": jnp.ones((4, 20, 5, 2))} + flattened = utils.flatten_list_parameters(testing) + assert "test" not in flattened.keys() + for suffix_first_dim in range(5): + for suffix_second_dim in range(2): + key = "test_%s_%s" % ( + str(suffix_first_dim), + str(suffix_second_dim), + ) + assert ( + key in flattened.keys() + ), "flatten_list_parameters not naming split params correctly." + assert flattened[key].shape == ( + 4, + 20, + ), "flatten_list_parameters breaking up wrong axis when passed >3"