From e9e9387b79d5fa9a72e82dd740b3926dad97e7a8 Mon Sep 17 00:00:00 2001 From: Tomas Stolker Date: Fri, 8 Mar 2024 17:12:13 +0100 Subject: [PATCH] Included uniform_priors and normal_priors to SamplesBox, added show_priors parameter to plot_posterior --- species/core/box.py | 6 ++++++ species/data/database.py | 11 +++++++++++ species/plot/plot_mcmc.py | 33 +++++++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/species/core/box.py b/species/core/box.py index ebc3129e..d3b8bd2a 100644 --- a/species/core/box.py +++ b/species/core/box.py @@ -366,6 +366,8 @@ def __init__(self): self.prob_sample = None self.median_sample = None self.attributes = None + self.uniform_priors = None + self.normal_priors = None class SpectrumBox(Box): @@ -555,6 +557,10 @@ def create_box(boxtype, **kwargs): box.prob_sample = kwargs["prob_sample"] box.median_sample = kwargs["median_sample"] box.attributes = kwargs["attributes"] + if "uniform_priors" in kwargs: + box.uniform_priors = kwargs["uniform_priors"] + if "normal_priors" in kwargs: + box.normal_priors = kwargs["normal_priors"] elif boxtype == "spectrum": box = SpectrumBox() diff --git a/species/data/database.py b/species/data/database.py index 812084c7..fbee3589 100644 --- a/species/data/database.py +++ b/species/data/database.py @@ -2834,6 +2834,9 @@ def get_samples( # Check if attributes are present for # backward compatibility + uniform_priors = {} + normal_priors = {} + if "n_bounds" in attributes and attributes["n_bounds"] > 0: dset_bounds = hdf5_file[f"results/fit/{tag}/bounds"] print("\nUniform priors (min, max):") @@ -2853,12 +2856,17 @@ def get_samples( print( f" - {bound_item}/{filter_item} = ({prior_bound[0]}, {prior_bound[1]})" ) + uniform_priors[f"{bound_item}/{filter_item}"] = ( + prior_bound[0], + prior_bound[1], + ) else: prior_bound = np.array(hdf5_file[group_path]) print( f" - {bound_item} = ({prior_bound[0]}, {prior_bound[1]})" ) + uniform_priors[bound_item] = (prior_bound[0], prior_bound[1]) if "n_normal_prior" in attributes and attributes["n_normal_prior"] > 0: dset_prior = hdf5_file[f"results/fit/{tag}/normal_prior"] @@ -2868,6 +2876,7 @@ def get_samples( group_path = f"results/fit/{tag}/normal_prior/{prior_item}" norm_prior = np.array(hdf5_file[group_path]) print(f" - {prior_item} = ({norm_prior[0]}, {norm_prior[1]})") + normal_priors[prior_item] = (norm_prior[0], norm_prior[1]) median_sample = self.get_median_sample(tag, burnin, verbose=False) prob_sample = self.get_probable_sample(tag, burnin, verbose=False) @@ -2893,6 +2902,8 @@ def get_samples( prob_sample=prob_sample, median_sample=median_sample, attributes=attributes, + uniform_priors=uniform_priors, + normal_priors=normal_priors, ) @typechecked diff --git a/species/plot/plot_mcmc.py b/species/plot/plot_mcmc.py index 878e51ab..27024a4f 100644 --- a/species/plot/plot_mcmc.py +++ b/species/plot/plot_mcmc.py @@ -15,6 +15,7 @@ from typeguard import typechecked from matplotlib.ticker import ScalarFormatter from scipy.interpolate import RegularGridInterpolator +from scipy.stats import norm from species.core import constants from species.util.convert_util import logg_to_mass @@ -218,6 +219,7 @@ def plot_posterior( output: Optional[str] = None, object_type: str = "planet", param_inc: Optional[List[str]] = None, + show_priors: bool = False, ) -> mpl.figure.Figure: """ Function to plot the posterior distribution @@ -283,6 +285,13 @@ def plot_posterior( posterior plot. This parameter can also be used to change the order of the parameters in the posterior plot. All parameters will be included if the argument is set to ``None``. + show_priors : bool + Plot the normal priors in the diagonal panels together with the + 1D marginalized posterior distributions. This will only show + the priors that had a normal distribution, so those that were + set with the ``normal_prior`` parameter in + :class:`~species.fit.fit_model.FitModel` and + :class:`~species.fit.retrieval.AtmosphericRetrieval.setup_retrieval`. Returns ------- @@ -729,6 +738,7 @@ def plot_posterior( # Update axes labels + box_param = box.parameters.copy() labels = update_labels(box.parameters, object_type=object_type) # Check if parameter values were fixed @@ -812,9 +822,28 @@ def plot_posterior( for i in range(ndim): for j in range(ndim): - if i >= j: - ax = axes[i, j] + ax = axes[i, j] + + if show_priors and i == j and box_param[i] in box.normal_priors: + norm_param = box.normal_priors[box_param[i]] + + x_norm = np.linspace( + norm_param[0] - 5.0 * norm_param[1], + norm_param[0] + 5.0 * norm_param[1], + 200, + ) + y_norm = norm.pdf(x_norm, norm_param[0], norm_param[1]) + + ax.plot( + x_norm, + 0.9 * ax.get_ylim()[1] * y_norm / np.amax(y_norm), + ls=":", + lw=2.0, + color="dodgerblue", + ) + + if i >= j: ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False)) ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))