Skip to content

Commit

Permalink
Finished implementing support for the abund_nodes parameter with free…
Browse files Browse the repository at this point in the history
… retrievals, added inc_abund parameter to plot_posterior, capital sensitive units in convert_units function
  • Loading branch information
tomasstolker committed Jul 14, 2023
1 parent b61760c commit 9695761
Show file tree
Hide file tree
Showing 8 changed files with 631 additions and 433 deletions.
216 changes: 141 additions & 75 deletions species/analysis/retrieval.py

Large diffs are not rendered by default.

64 changes: 48 additions & 16 deletions species/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,13 +1081,13 @@ def add_object(
if verbose:
print(f" - {mag_item}:")

print(f" - Mean wavelength (um) = {mean_wavel:.4e}")

print(
f" - Apparent magnitude = {app_mag[mag_item][0]:.2f} +/- "
f"{app_mag[mag_item][1]:.2f}"
)

print(f" - Mean wavelength (um) = {mean_wavel:.4e}")

print(
f" - Flux (W m-2 um-1) = {flux[mag_item]:.2e} +/- "
f"{error[mag_item]:.2e}"
Expand Down Expand Up @@ -1124,13 +1124,13 @@ def add_object(
app_mag_item = (dered_mag, app_mag[mag_item][i][1])

if verbose:
print(f" - Mean wavelength (um) = {mean_wavel:.4e}")

print(
f" - Apparent magnitude = {app_mag_item[0]:.2f} +/- "
f"{app_mag_item[1]:.2f}"
)

print(f" - Mean wavelength (um) = {mean_wavel:.4e}")

print(
f" - Flux (W m-2 um-1) = {flux[mag_item][i]:.2e} +/- "
f"{error[mag_item][i]:.2e}"
Expand Down Expand Up @@ -1184,14 +1184,18 @@ def add_object(

if flux_item in units:
flux_in = np.array([[mean_wavel, data[2], data[3]]])
flux_out = data_util.convert_units(flux_in, ('um', units[flux_item]))
flux_out = data_util.convert_units(
flux_in, ("um", units[flux_item])
)

data = [np.nan, np.nan, flux_out[0, 1], flux_out[0, 2]]

if verbose:
print(f" - {flux_item}:")
print(f" - Mean wavelength (um) = {mean_wavel:.4e}")
print(f" - Flux (W m-2 um-1) = {data[2]:.2e} +/- {data[3]:.2e}")
print(
f" - Flux (W m-2 um-1) = {data[2]:.2e} +/- {data[3]:.2e}"
)

# None, None, (W m-2 um-1), (W m-2 um-1)
dset = h5_file.create_dataset(
Expand Down Expand Up @@ -1230,7 +1234,9 @@ def add_object(
covariance = hdulist[1].data["COVARIANCE"] # (W m-2 um-1)^2
error = np.sqrt(np.diag(covariance)) # (W m-2 um-1)

read_spec[spec_item] = np.column_stack([wavelength, flux, error])
read_spec[spec_item] = np.column_stack(
[wavelength, flux, error]
)

else:
# Otherwise try to read a 2D dataset with 3 columns
Expand All @@ -1246,7 +1252,9 @@ def add_object(
and spec_item not in read_spec
):
if spec_item in units:
data = data_util.convert_units(data, units[spec_item])
data = data_util.convert_units(
data, units[spec_item]
)

read_spec[spec_item] = data

Expand Down Expand Up @@ -1294,7 +1302,10 @@ def add_object(
)
read_spec[spec_item][:, 1] *= 10.0 ** (0.4 * ext_mag)

if read_spec[spec_item].shape[0] == 3 and read_spec[spec_item].shape[1] != 3:
if (
read_spec[spec_item].shape[0] == 3
and read_spec[spec_item].shape[1] != 3
):
warnings.warn(
f"Transposing the data of {spec_item} because "
f"the first instead of the second axis "
Expand Down Expand Up @@ -1363,12 +1374,14 @@ def add_object(

else:
if spec_item in units:
warnings.warn("The unit conversion has not been "
"implemented for covariance matrices. "
"Please open an issue on the Github "
"page if such functionality is required "
"or provide the file with covariances "
"in (W m-2 um-1)^2.")
warnings.warn(
"The unit conversion has not been "
"implemented for covariance matrices. "
"Please open an issue on the Github "
"page if such functionality is required "
"or provide the file with covariances "
"in (W m-2 um-1)^2."
)

# Otherwise try to read a square, 2D dataset
if verbose:
Expand All @@ -1385,7 +1398,10 @@ def add_object(

if data.ndim == 2 and data.shape[0] == data.shape[1]:
if spec_item not in read_cov:
if data.shape[0] == read_spec[spec_item].shape[0]:
if (
data.shape[0]
== read_spec[spec_item].shape[0]
):
if np.all(np.diag(data) == 1.0):
warnings.warn(corr_warn)

Expand Down Expand Up @@ -3309,6 +3325,11 @@ def add_retrieval(
if "max_press" in radtrans:
dset.attrs["max_press"] = radtrans["max_press"]

if "abund_nodes" in radtrans:
dset.attrs["abund_nodes"] = radtrans["abund_nodes"]
else:
dset.attrs["abund_nodes"] = "None"

print(" [DONE]")

# Set number of pressures
Expand Down Expand Up @@ -3675,6 +3696,16 @@ def get_retrieval_spectra(
n_line_species = dset.attrs["n_line_species"]
n_cloud_species = dset.attrs["n_cloud_species"]

# Get number of abundance nodes

if "abund_nodes" in dset.attrs:
if dset.attrs["abund_nodes"] == "None":
abund_nodes = None
else:
abund_nodes = dset.attrs["abund_nodes"]
else:
abund_nodes = None

# Convert numpy boolean to regular boolean

scattering = bool(dset.attrs["scattering"])
Expand Down Expand Up @@ -3816,6 +3847,7 @@ def get_retrieval_spectra(
distance=distance,
pt_smooth=pt_smooth,
temp_nodes=temp_nodes,
abund_nodes=abund_nodes,
read_rad=read_rad,
sample=item,
)
Expand Down
Loading

0 comments on commit 9695761

Please sign in to comment.