Skip to content

Commit

Permalink
Retrieving log_p_base with physical clouds instead of using the conde…
Browse files Browse the repository at this point in the history
…nsation profile
  • Loading branch information
tomasstolker committed Feb 8, 2024
1 parent 9e1abe1 commit d491cf0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 23 deletions.
23 changes: 22 additions & 1 deletion species/fit/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,9 @@ def _set_parameters(
self.parameters.append("sigma_lnorm")

for item in self.cloud_species:
if "log_p_base" in bounds:
self.parameters.append(f"log_p_base_{item}")

cloud_lower = item[:-3].lower()

if f"{cloud_lower}_tau" in bounds:
Expand Down Expand Up @@ -1495,6 +1498,18 @@ def _prior_transform(

cube[cube_index["sigma_lnorm"]] = sigma_lnorm

if "log_p_base" in bounds:
for item in self.cloud_species:
# Use the same cloud base pressure range
# for the different cloud species
log_p_base = (
bounds["log_p_base"][0]
+ (bounds["log_p_base"][1] - bounds["log_p_base"][0])
* cube[cube_index[f"log_p_base_{item}"]]
)

cube[cube_index[f"log_p_base_{item}"]] = log_p_base

if "log_tau_cloud" in bounds:
log_tau_cloud = (
bounds["log_tau_cloud"][0]
Expand Down Expand Up @@ -2154,9 +2169,16 @@ def _lnlike(
for item in cloud_param:
if item in self.parameters:
cloud_dict[item] = cube[cube_index[item]]

# elif item in ['log_kzz', 'sigma_lnorm']:
# cloud_dict[item] = None

for item in self.cloud_species:
if f"log_p_base_{item}" in self.parameters:
cloud_dict[f"log_p_base_{item}"] = cube[
cube_index[f"log_p_base_{item}"]
]

elif "fsed_1" in self.parameters and "fsed_2" in self.parameters:
cloud_param_1 = [
"fsed_1",
Expand Down Expand Up @@ -3931,7 +3953,6 @@ def run_dynesty(
],
ptform_args=[self.bounds, self.cube_index],
) as pool:

print(f"Initialized a dynesty.pool with {n_pool} workers")

if dynamic:
Expand Down
6 changes: 5 additions & 1 deletion species/util/plot_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,13 +581,17 @@ def update_labels(param: List[str], object_type: str = "planet") -> List[str]:
if f"{item}_fraction" in param:
index = param.index(f"{item}_fraction")
param[index] = (
rf"$\log\,\tilde{{\mathrm{{X}}}}" rf"_\mathrm{{{cloud_labels[i]}}}$"
rf"$\log\,\tilde{{\mathrm{{X}}}}_\mathrm{{{cloud_labels[i]}}}$"
)

if f"{item}_tau" in param:
index = param.index(f"{item}_tau")
param[index] = rf"$\bar{{\tau}}_\mathrm{{{cloud_labels[i]}}}$"

if f"log_p_base_{item}" in param:
index = param.index(f"log_p_base_{item}")
param[index] = rf"$\log\,P_\mathrm{{{cloud_labels[i]}}}$"

for i, item_i in enumerate(cloud_species):
for j, item_j in enumerate(cloud_species):
if f"{item_i}_{item_j}_ratio" in param:
Expand Down
49 changes: 28 additions & 21 deletions species/util/retrieval_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,29 +488,31 @@ def create_pt_profile(
)

elif pt_profile == "gradient":
num_layer = 6 # could make a variable in the future
num_layer = 6 # could make a variable in the future
layer_pt_slopes = np.ones(num_layer) * np.nan
for index in range(num_layer):
layer_pt_slopes[index] = cube[cube_index[f'PTslope_{num_layer - index}']]
layer_pt_slopes[index] = cube[cube_index[f"PTslope_{num_layer - index}"]]

try:
from petitRADTRANS.physics import dTdP_temperature_profile
except ImportError as e:
raise ImportError(
"""Can\'t import the dTdP profile function from petitRADTRANS,
check that your version of pRT includes this function in
petitRADTRANS.physics""", e
petitRADTRANS.physics""",
e,
)

temp = dTdP_temperature_profile(pressure,
num_layer, # could change in the future
layer_pt_slopes,
cube[cube_index["T_bottom"]])

temp = dTdP_temperature_profile(
pressure,
num_layer, # could change in the future
layer_pt_slopes,
cube[cube_index["T_bottom"]],
)

phot_press = None
conv_press = None


elif pt_profile in ["free", "monotonic"]:
knot_temp = []
for i in range(knot_press.shape[0]):
Expand Down Expand Up @@ -1188,15 +1190,22 @@ def calc_spectrum_clouds(
p_base = {}

for cloud_item in log_x_base:
p_base_item = find_cloud_deck(
cloud_item,
pressure,
temperature,
metallicity,
c_o_ratio,
mmw=np.mean(mmw),
plotting=plotting,
)
if f"log_p_base_{cloud_item}(c)" in cloud_dict:
p_base_item = 10.0 ** cloud_dict[f"log_p_base_{cloud_item}(c)"]
p_base[f"{cloud_item}(c)"] = p_base_item

else:
p_base_item = find_cloud_deck(
cloud_item,
pressure,
temperature,
metallicity,
c_o_ratio,
mmw=np.mean(mmw),
plotting=plotting,
)

p_base[f"{cloud_item}(c)"] = p_base_item

abund_in[f"{cloud_item}(c)"] = np.zeros_like(temperature)

Expand All @@ -1206,8 +1215,6 @@ def calc_spectrum_clouds(
** cloud_dict["fsed"]
)

p_base[f"{cloud_item}(c)"] = p_base_item

# Adaptive pressure refinement around the cloud base

if pressure_grid == "clouds":
Expand Down Expand Up @@ -3053,4 +3060,4 @@ def convective_flux(

f_conv[np.isnan(f_conv)] = 0.0

return f_conv # (W m-2)
return f_conv # (W m-2)

0 comments on commit d491cf0

Please sign in to comment.