Skip to content

Commit

Permalink
Included uniform_priors and normal_priors to SamplesBox, added show_p…
Browse files Browse the repository at this point in the history
…riors parameter to plot_posterior
  • Loading branch information
tomasstolker committed Mar 8, 2024
1 parent c897514 commit e9e9387
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
6 changes: 6 additions & 0 deletions species/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ def __init__(self):
self.prob_sample = None
self.median_sample = None
self.attributes = None
self.uniform_priors = None
self.normal_priors = None


class SpectrumBox(Box):
Expand Down Expand Up @@ -555,6 +557,10 @@ def create_box(boxtype, **kwargs):
box.prob_sample = kwargs["prob_sample"]
box.median_sample = kwargs["median_sample"]
box.attributes = kwargs["attributes"]
if "uniform_priors" in kwargs:
box.uniform_priors = kwargs["uniform_priors"]
if "normal_priors" in kwargs:
box.normal_priors = kwargs["normal_priors"]

elif boxtype == "spectrum":
box = SpectrumBox()
Expand Down
11 changes: 11 additions & 0 deletions species/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2834,6 +2834,9 @@ def get_samples(
# Check if attributes are present for
# backward compatibility

uniform_priors = {}
normal_priors = {}

if "n_bounds" in attributes and attributes["n_bounds"] > 0:
dset_bounds = hdf5_file[f"results/fit/{tag}/bounds"]
print("\nUniform priors (min, max):")
Expand All @@ -2853,12 +2856,17 @@ def get_samples(
print(
f" - {bound_item}/{filter_item} = ({prior_bound[0]}, {prior_bound[1]})"
)
uniform_priors[f"{bound_item}/{filter_item}"] = (
prior_bound[0],
prior_bound[1],
)

else:
prior_bound = np.array(hdf5_file[group_path])
print(
f" - {bound_item} = ({prior_bound[0]}, {prior_bound[1]})"
)
uniform_priors[bound_item] = (prior_bound[0], prior_bound[1])

if "n_normal_prior" in attributes and attributes["n_normal_prior"] > 0:
dset_prior = hdf5_file[f"results/fit/{tag}/normal_prior"]
Expand All @@ -2868,6 +2876,7 @@ def get_samples(
group_path = f"results/fit/{tag}/normal_prior/{prior_item}"
norm_prior = np.array(hdf5_file[group_path])
print(f" - {prior_item} = ({norm_prior[0]}, {norm_prior[1]})")
normal_priors[prior_item] = (norm_prior[0], norm_prior[1])

median_sample = self.get_median_sample(tag, burnin, verbose=False)
prob_sample = self.get_probable_sample(tag, burnin, verbose=False)
Expand All @@ -2893,6 +2902,8 @@ def get_samples(
prob_sample=prob_sample,
median_sample=median_sample,
attributes=attributes,
uniform_priors=uniform_priors,
normal_priors=normal_priors,
)

@typechecked
Expand Down
33 changes: 31 additions & 2 deletions species/plot/plot_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typeguard import typechecked
from matplotlib.ticker import ScalarFormatter
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import norm

from species.core import constants
from species.util.convert_util import logg_to_mass
Expand Down Expand Up @@ -218,6 +219,7 @@ def plot_posterior(
output: Optional[str] = None,
object_type: str = "planet",
param_inc: Optional[List[str]] = None,
show_priors: bool = False,
) -> mpl.figure.Figure:
"""
Function to plot the posterior distribution
Expand Down Expand Up @@ -283,6 +285,13 @@ def plot_posterior(
posterior plot. This parameter can also be used to change the
order of the parameters in the posterior plot. All parameters
will be included if the argument is set to ``None``.
show_priors : bool
Plot the normal priors in the diagonal panels together with the
1D marginalized posterior distributions. This will only show
the priors that had a normal distribution, so those that were
set with the ``normal_prior`` parameter in
:class:`~species.fit.fit_model.FitModel` and
:class:`~species.fit.retrieval.AtmosphericRetrieval.setup_retrieval`.
Returns
-------
Expand Down Expand Up @@ -729,6 +738,7 @@ def plot_posterior(

# Update axes labels

box_param = box.parameters.copy()
labels = update_labels(box.parameters, object_type=object_type)

# Check if parameter values were fixed
Expand Down Expand Up @@ -812,9 +822,28 @@ def plot_posterior(

for i in range(ndim):
for j in range(ndim):
if i >= j:
ax = axes[i, j]
ax = axes[i, j]

if show_priors and i == j and box_param[i] in box.normal_priors:
norm_param = box.normal_priors[box_param[i]]

x_norm = np.linspace(
norm_param[0] - 5.0 * norm_param[1],
norm_param[0] + 5.0 * norm_param[1],
200,
)

y_norm = norm.pdf(x_norm, norm_param[0], norm_param[1])

ax.plot(
x_norm,
0.9 * ax.get_ylim()[1] * y_norm / np.amax(y_norm),
ls=":",
lw=2.0,
color="dodgerblue",
)

if i >= j:
ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))
ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))

Expand Down

0 comments on commit e9e9387

Please sign in to comment.