Skip to content

Commit

Permalink
Added get_condensation_curve function in retrieval_util, added inc_mo…
Browse files Browse the repository at this point in the history
…del_name parameter in plot_spectrum, some minor maintenance
  • Loading branch information
tomasstolker committed Feb 2, 2023
1 parent 3771aef commit 10dff5a
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 179 deletions.
18 changes: 9 additions & 9 deletions species/data/companion_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"Males et al. 2014",
"Mirek Brandt et al. 2021",
"Stolker et al. 2019",
"Stolker et al. 2020"
"Stolker et al. 2020a"
]
},
"beta Pic c": {
Expand Down Expand Up @@ -151,7 +151,7 @@
"Cheetham et al. 2019",
"Gaia Early Data Release 3",
"Marleau et al. 2019",
"Stolker et al. 2020"
"Stolker et al. 2020a"
]
},
"51 Eri b": {
Expand Down Expand Up @@ -613,8 +613,7 @@
"Gaia Early Data Release 3",
"Haffert et al. 2019",
"Hashimoto et al. 2020",
"Stolker et al. 2020",
"Stolker et al. 2020.",
"Stolker et al. 2020b",
"Wang et al. 2020",
"Wang et al. 2021"
]
Expand Down Expand Up @@ -651,7 +650,7 @@
"references": [
"Gaia Early Data Release 3",
"Haffert et al. 2019",
"Stolker et al. 2020",
"Stolker et al. 2020b",
"Wang et al. 2020",
"Wang et al. 2021"
]
Expand Down Expand Up @@ -807,7 +806,7 @@
"Gaia Early Data Release 3",
"Grandjean et al. 2019",
"Milli et al. 2017",
"Stolker et al. 2020",
"Stolker et al. 2020a",
"Ward-Duong et al. 2020"
]
},
Expand Down Expand Up @@ -1065,7 +1064,7 @@
"Maire et al. 2015",
"Maire et al. 2016",
"Musso Barcucci et al. 2019",
"Stolker et al. 2020"
"Stolker et al. 2020a"
]
},
"kappa And b": {
Expand Down Expand Up @@ -2008,15 +2007,16 @@
0.0
],
"mass_companion": [
19.0,
17.0,
5.0
],
"accretion": true,
"references": [
"Gaia Early Data Release 3",
"Hartmann et al. 1998",
"Schmidt et al. 2008",
"Wu et al. 2015"
"Wu et al. 2015",
"Wu et al. 2020"
]
},
"SR 12 C": {
Expand Down
61 changes: 41 additions & 20 deletions species/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import os
import pathlib
import sys
# import urllib.error
import warnings

Expand Down Expand Up @@ -2632,10 +2633,10 @@ def get_pt_profiles(
self, tag: str, random: Optional[int] = None, out_file: Optional[str] = None
) -> Tuple[np.ndarray, np.ndarray]:
"""
Function for returning the pressure-temperature profiles from
the posterior of the atmospheric retrieval with
``petitRADTRANS``. The data can also optionally be written to
an output file.
Function for returning the pressure-temperature profiles
from the posterior of the atmospheric retrieval with
``petitRADTRANS``. The data can also optionally be
written to an output file.
Parameters
----------
Expand Down Expand Up @@ -2682,6 +2683,11 @@ def get_pt_profiles(
elif "nparam" in dset.attrs:
n_param = dset.attrs["nparam"]

if "temp_nodes" in dset.attrs:
temp_nodes = dset.attrs["temp_nodes"]
else:
temp_nodes = 15

samples = np.asarray(dset)

if random is None:
Expand Down Expand Up @@ -2752,15 +2758,15 @@ def get_pt_profiles(
else:
pt_smooth = 0.0

knot_press = np.logspace(np.log10(press[0]), np.log10(press[-1]), 15)
knot_press = np.logspace(np.log10(press[0]), np.log10(press[-1]), temp_nodes)

knot_temp = []
for j in range(15):
knot_temp.append(item[param_index[f"t{i}"]])
for k in range(temp_nodes):
knot_temp.append(item[param_index[f"t{k}"]])

knot_temp = np.asarray(knot_temp)

temp[:, j] = retrieval_util.pt_spline_interp(
temp[:, i] = retrieval_util.pt_spline_interp(
knot_press, knot_temp, press, pt_smooth
)

Expand Down Expand Up @@ -3132,8 +3138,10 @@ def add_retrieval(

if "temp_nodes" not in radtrans or radtrans["temp_nodes"] is None:
dset.attrs["temp_nodes"] = "None"
temp_nodes = 15
else:
dset.attrs["temp_nodes"] = radtrans["temp_nodes"]
temp_nodes = radtrans["temp_nodes"]

if "pt_smooth" in radtrans:
dset.attrs["pt_smooth"] = radtrans["pt_smooth"]
Expand Down Expand Up @@ -3165,10 +3173,10 @@ def add_retrieval(
print(" [DONE]")

print("Importing chemistry module...", end="", flush=True)
from poor_mans_nonequ_chem_FeH.poor_mans_nonequ_chem.poor_mans_nonequ_chem import (
interpol_abundances,
)

if "poor_mans_nonequ_chem" in sys.modules:
from poor_mans_nonequ_chem.poor_mans_nonequ_chem import interpol_abundances
else:
from petitRADTRANS.poor_mans_nonequ_chem.poor_mans_nonequ_chem import interpol_abundances
print(" [DONE]")

rt_object = Radtrans(
Expand Down Expand Up @@ -3221,11 +3229,11 @@ def add_retrieval(
or radtrans["pt_profile"] == "monotonic"
):
knot_press = np.logspace(
np.log10(pressure[0]), np.log10(pressure[-1]), 15
np.log10(pressure[0]), np.log10(pressure[-1]), temp_nodes
)

knot_temp = []
for k in range(15):
for k in range(temp_nodes):
knot_temp.append(sample_dict[f"t{k}"])

knot_temp = np.asarray(knot_temp)
Expand Down Expand Up @@ -3325,11 +3333,11 @@ def add_retrieval(
or radtrans["pt_profile"] == "monotonic"
):
knot_press = np.logspace(
np.log10(pressure[0]), np.log10(pressure[-1]), 15
np.log10(pressure[0]), np.log10(pressure[-1]), temp_nodes
)

knot_temp = []
for k in range(15):
for k in range(temp_nodes):
knot_temp.append(sample_dict[f"t{k}"])

knot_temp = np.asarray(knot_temp)
Expand Down Expand Up @@ -3862,6 +3870,16 @@ def petitcode_param(
else:
p_quench = None

if "temp_nodes" in sample_box.attributes:
temp_nodes = sample_box.attributes["temp_nodes"]
else:
temp_nodes = 15

if "pressure_grid" in sample_box.attributes:
pressure_grid = sample_box.attributes["pressure_grid"]
else:
pressure_grid = "smaller"

pressure = np.logspace(-6.0, 3.0, 180)

if sample_box.attributes["pt_profile"] == "molliere":
Expand All @@ -3876,10 +3894,10 @@ def petitcode_param(
)

else:
knot_press = np.logspace(np.log10(pressure[0]), np.log10(pressure[-1]), 15)
knot_press = np.logspace(np.log10(pressure[0]), np.log10(pressure[-1]), temp_nodes)

knot_temp = []
for i in range(15):
for i in range(temp_nodes):
knot_temp.append(model_param[f"t{i}"])

knot_temp = np.asarray(knot_temp)
Expand All @@ -3893,7 +3911,10 @@ def petitcode_param(
knot_press, knot_temp, pressure, pt_smooth=pt_smooth
)

from poor_mans_nonequ_chem.poor_mans_nonequ_chem import interpol_abundances
if "poor_mans_nonequ_chem" in sys.modules:
from poor_mans_nonequ_chem.poor_mans_nonequ_chem import interpol_abundances
else:
from petitRADTRANS.poor_mans_nonequ_chem.poor_mans_nonequ_chem import interpol_abundances

# Interpolate the abundances, following chemical equilibrium
abund_in = interpol_abundances(
Expand Down Expand Up @@ -3976,7 +3997,7 @@ def petitcode_param(
cloud_species=cloud_species,
scattering=True,
wavel_range=(0.5, 50.0),
pressure_grid=sample_box.attributes["pressure_grid"],
pressure_grid=pressure_grid,
res_mode="c-k",
cloud_wavel=cloud_wavel,
)
Expand Down
32 changes: 16 additions & 16 deletions species/plot/plot_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,7 @@ def plot_pt_profile(
if "poor_mans_nonequ_chem" in sys.modules:
from poor_mans_nonequ_chem.poor_mans_nonequ_chem import interpol_abundances
else:
from petitRADTRANS.poor_mans_nonequ_chem.poor_mans_nonequ_chem import (
interpol_abundances,
)
from petitRADTRANS.poor_mans_nonequ_chem.poor_mans_nonequ_chem import interpol_abundances

abund_in = interpol_abundances(
np.full(pressure.shape[0], median["c_o_ratio"]),
Expand Down Expand Up @@ -526,17 +524,19 @@ def plot_pt_profile(
pressure_grid=radtrans.pressure_grid,
)

for item in cloud_species:
if item in radtrans.cloud_species:
sat_press, sat_temp = retrieval_util.return_T_cond_Fe_comb(
median["metallicity"],
median["c_o_ratio"],
MMW=np.mean(abund_in["MMW"]),
)
for cloud_item in cloud_species:

if cloud_item in radtrans.cloud_species:
cond_temp = retrieval_util.get_condensation_curve(
composition=cloud_item[:-3],
press=pressure,
metallicity=median["metallicity"],
c_o_ratio=median["c_o_ratio"],
mmw=np.mean(abund_in["MMW"]))

ax.plot(
sat_temp,
sat_press,
cond_temp,
pressure,
"--",
lw=0.8,
color=next(color_iter, "black"),
Expand Down Expand Up @@ -721,12 +721,12 @@ def plot_pt_profile(

color_iter = iter(cloud_colors)

for item in cloud_species:
if item in radtrans.cloud_species:
cloud_index = radtrans.rt_object.cloud_species.index(item)
for cloud_item in cloud_species:
if cloud_item in radtrans.cloud_species:
cloud_index = radtrans.rt_object.cloud_species.index(cloud_item)

label = ""
for char in item[:-3]:
for char in cloud_item[:-3]:
if char.isnumeric():
label += f"$_{char}$"
else:
Expand Down
15 changes: 15 additions & 0 deletions species/plot/plot_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def plot_spectrum(
output: Optional[str] = "spectrum.pdf",
leg_param: Optional[List[str]] = None,
grid_hspace: float = 0.1,
inc_model_name: bool = False,
):
"""
Function for plotting a spectral energy distribution and combining
Expand Down Expand Up @@ -147,6 +148,9 @@ def plot_spectrum(
The relative height spacing between subplots, expressed
as a fraction of the average axis height. The default
value is set to 0.1.
inc_model_name : bool
Include the model name in the legend of any
:class:`~species.core.box.ModelBox`.
Returns
-------
Expand Down Expand Up @@ -467,6 +471,10 @@ def plot_spectrum(
if item not in leg_param:
del param[item]

if leg_param is not None:
param_new = {k: param[k] for k in leg_param}
param = param_new.copy()

par_key, par_unit, par_label = plot_util.quantity_unit(
param=list(param.keys()), object_type=object_type
)
Expand All @@ -475,6 +483,7 @@ def plot_spectrum(
# newline = False

for i, item in enumerate(par_key):

if item[:4] == "teff":
value = f"{param[item]:.0f}"

Expand Down Expand Up @@ -534,15 +543,21 @@ def plot_spectrum(
# label += '\n'
# newline = True

model_name = plot_util.model_name(box_item.model)

if par_unit[i] is None:
if len(label) > 0:
label += ", "
elif inc_model_name:
label += f"{model_name}: "

label += f"{par_label[i]} = {value}"

else:
if len(label) > 0:
label += ", "
elif inc_model_name:
label += f"{model_name}: "

label += f"{par_label[i]} = {value} {par_unit[i]}"

Expand Down
Loading

0 comments on commit 10dff5a

Please sign in to comment.